From 45512a2b617d7cab3df409ccee966ddb634d8657 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Wed, 12 Feb 2025 15:34:04 +0100
Subject: [PATCH 1/6] fix: add grad spike detection

---
 changelog.md                                |   2 +
 docs/tutorials/training.md                  |   4 +-
 docs/tutorials/tuning.md                    |   2 +-
 edsnlp/training/trainer.py                  | 245 +++++++++++++++-----
 tests/training/dep_parser_config.yml        |   2 +
 tests/training/ner_qlf_diff_bert_config.yml |   1 +
 tests/training/ner_qlf_same_bert_config.yml |   1 +
 7 files changed, 201 insertions(+), 56 deletions(-)

diff --git a/changelog.md b/changelog.md
index 1eb1dde93..f853a1c6a 100644
--- a/changelog.md
+++ b/changelog.md
@@ -10,6 +10,7 @@
 - `docs/tutorials/tuning.md`: New tutorial for hyperparameter tuning.
 - Provided a [detailed tutorial](./docs/tutorials/tuning.md) on hyperparameter tuning, covering usage scenarios and configuration options.
 - `ScheduledOptimizer` (e.g., `@core: "optimizer"`) now supports importing optimizers using their qualified name (e.g., `optim: "torch.optim.Adam"`).
+- Added gradient spike detection to the `edsnlp.train` script, as well as per-weight-layer gradient logging.

 ### Changed

@@ -27,6 +28,7 @@
 - Ensure we don't overwrite the RNG of the data reader when calling `stream.shuffle()` with no seed
 - Raise an error if the batch size in `stream.shuffle(batch_size=...)` is not compatible with the stream
 - `eds.split` now keeps doc and span attributes in the sub-documents.
+- Fixed mini-batch accumulation for multi-task training.

 # v0.15.0 (2024-12-13)

diff --git a/docs/tutorials/training.md b/docs/tutorials/training.md
index 3ec95f1f9..e7d887364 100644
--- a/docs/tutorials/training.md
+++ b/docs/tutorials/training.md
@@ -179,7 +179,7 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
         val_data: ${ val_data }
         max_steps: 2000
         validation_interval: ${ train.max_steps//10 }
-        max_grad_norm: 1.0
+        grad_max_norm: 1.0
         scorer: ${ scorer }
         optimizer: ${ optimizer }
         # Do preprocessing in parallel on 1 worker
@@ -284,7 +284,7 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
         val_data=val_data,
         scorer={"ner": ner_metric},
         optimizer=optimizer,
-        max_grad_norm=1.0,
+        grad_max_norm=1.0,
         output_dir="artifacts",
         # Do preprocessing in parallel on 1 worker
         num_workers=1,
diff --git a/docs/tutorials/tuning.md b/docs/tutorials/tuning.md
index 5079043be..01a131242 100644
--- a/docs/tutorials/tuning.md
+++ b/docs/tutorials/tuning.md
@@ -233,7 +233,7 @@ train:
   val_data: ${ val_data }
   max_steps: 400
   validation_interval: ${ train.max_steps//2 }
-  max_grad_norm: 1.0
+  grad_max_norm: 1.0
   scorer: ${ scorer }
   optimizer: ${ optimizer }
   num_workers: 2
diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py
index 896cc707e..d877a3895 100644
--- a/edsnlp/training/trainer.py
+++ b/edsnlp/training/trainer.py
@@ -1,4 +1,5 @@
 import json
+import math
 import os
 import time
 import warnings
@@ -31,10 +32,13 @@
 from edsnlp.core.stream import Stream
 from edsnlp.metrics.ner import NerMetric
 from edsnlp.metrics.span_attributes import SpanAttributeMetric
-from edsnlp.pipes.base import BaseNERComponent, BaseSpanAttributeClassifierComponent
+from edsnlp.pipes.base import (
+    BaseNERComponent,
+    BaseSpanAttributeClassifierComponent,
+)
 from edsnlp.utils.batching import BatchSizeArg, stat_batchify
 from edsnlp.utils.bindings import BINDING_SETTERS
-from edsnlp.utils.collections import chain_zip, flatten, ld_to_dl
+from edsnlp.utils.collections import chain_zip, flatten, flatten_once, ld_to_dl
 from edsnlp.utils.span_getters
import get_spans
from edsnlp.utils.typing import AsList

@@ -62,6 +66,10 @@
         "goal_wait": 1,
         "name": r"\1_\2",
     },
+    "grad_norm/__all__": {
+        "format": "{:.2e}",
+        "name": "grad_norm",
+    },
 }


@@ -208,6 +216,47 @@ def __call__(self, nlp: Pipeline, docs: Iterable[Any]):
 GenericScorer = Union[GenericScorer, Dict]


+def ewm_moments(x, window, adjust=True, bias=False, state=None):
+    # Exponentially weighted running mean and variance of a stream of values,
+    # following the same conventions as pandas' `ewm`.
+    if state is None:
+        # First observation: initialize the accumulators, the variance is
+        # still undefined at this point
+        alpha = 2.0 / (window + 1)
+        decay = 1 - alpha
+        fresh_weight = 1 if adjust else alpha
+        mean_val = x
+        var_val = 0.0
+        sum_w = 1.0
+        sum_w2 = 1.0
+        old_w = 1.0
+        return (
+            mean_val,
+            float("nan"),
+            [decay, fresh_weight, mean_val, var_val, sum_w, sum_w2, old_w],
+        )
+    else:
+        decay, fresh_weight, mean_val, var_val, sum_w, sum_w2, old_w = state

+    # Decay the accumulated weights and fold in the new observation
+    sum_w *= decay
+    sum_w2 *= decay * decay
+    old_w *= decay
+    old_m = mean_val
+    denom = old_w + fresh_weight
+    mean_val = (old_w * old_m + fresh_weight * x) / denom
+    d1 = old_m - mean_val
+    d2 = x - mean_val
+    var_val = (old_w * (var_val + d1 * d1) + fresh_weight * d2 * d2) / denom
+    sum_w += fresh_weight
+    sum_w2 += fresh_weight * fresh_weight
+    old_w += fresh_weight

+    state = [decay, fresh_weight, mean_val, var_val, sum_w, sum_w2, old_w]

+    if not bias:
+        # Correct the weighted variance for bias (reliability weights)
+        num = sum_w * sum_w
+        den = num - sum_w2
+        var_val = var_val * (num / den) if den > 0 else float("nan")

+    return mean_val, var_val, state


 def default_optim(
     trained_pipes,
     *,
@@ -362,7 +411,10 @@ def train(
     optimizer: Union[ScheduledOptimizer, torch.optim.Optimizer] = None,
     validation_interval: Optional[int] = None,
     checkpoint_interval: Optional[int] = None,
-    max_grad_norm: float = 5.0,
+    grad_max_norm: float = 5.0,
+    grad_ewm_window: int = 100,
+    grad_dev_policy: Optional[Literal["clip_mean", "clip_threshold", "skip"]] = None,
+    grad_max_dev: float = 7.0,
     loss_scales: Dict[str, float] = {},
     scorer: GenericScorer = GenericScorer(),
     num_workers: int = 0,
@@ -372,8 +424,9 @@
     output_model_dir: Optional[Union[Path, str]] = None,
     save_model: bool = True,
     logger: bool = True,
-    config_meta: Optional[Dict] = None,
+    log_weight_grads: bool = False,
     on_validation_callback: Optional[Callable[[Dict], None]] = None,
+    config_meta: Optional[Dict] = None,
     **kwargs,
 ):
     """
@@ -418,8 +471,26 @@
         The number of steps between each evaluation. Defaults to 1/10 of max_steps
     checkpoint_interval: Optional[int]
         The number of steps between each model save. Defaults to validation_interval
-    max_grad_norm: float
+    grad_max_norm: float
         The maximum gradient norm
+    grad_dev_policy: Optional[Literal["clip_mean", "clip_threshold", "skip"]]
+        The policy to apply when a gradient spike is detected, i.e. when the
+        gradient norm is higher than the mean + std * grad_max_dev. Can be:
+
+        - "clip_mean": clip the gradients to the mean gradient norm
+        - "clip_threshold": clip the gradients to the mean + std * grad_max_dev
+        - "skip": skip the step
+
+        These do not apply to `grad_max_norm`, which is always enforced when it is
+        not None, since `grad_max_norm` is not adaptive and would most likely
+        prevent the model from learning during the early stages of training when
+        gradients are expected to be high.
+    grad_ewm_window: int
+        The approximate number of past steps over which the exponentially weighted
+        mean and variance of the gradient norm are computed to detect spikes.
+    grad_max_dev: float
+        The threshold to apply to detect gradient spikes. A spike is detected
+        when the gradient norm is higher than the mean + std * grad_max_dev.
loss_scales: Dict[str, float] The loss scales for each component (useful for multi-task learning) scorer: GenericScorer @@ -455,6 +526,8 @@ def train( spending time dumping the model weights to the disk. logger: bool Whether to log the validation metrics in a rich table. + log_weight_grads: bool + Whether to log the weight gradients during training. on_validation_callback: Optional[Callable[[Dict], None]] A callback function invoked during validation steps to handle custom logic. kwargs: Dict @@ -471,11 +544,18 @@ def train( is_main_process = accelerator.is_main_process device = accelerator.device + if "max_grad_norm" in kwargs: + warnings.warn( + "The 'max_grad_norm' argument is deprecated. Use 'grad_max_norm' instead." + ) + grad_max_norm = kwargs.pop("max_grad_norm") + output_dir = Path(output_dir or Path.cwd() / "artifacts") - output_model_dir = output_model_dir or output_dir / "model-last" + output_model_dir = Path(output_model_dir or output_dir / "model-last") train_metrics_path = output_dir / "train_metrics.json" if is_main_process: os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_model_dir, exist_ok=True) if config_meta is not None: # pragma: no cover print(config_meta["unresolved_config"].to_yaml_str()) config_meta["unresolved_config"].to_disk(output_dir / "train_config.yml") @@ -517,7 +597,8 @@ def train( nlp.post_init(chain_zip([td.data for td in train_data if td.post_init])) for phase_i, pipe_names in enumerate(phases): - trained_pipes = PipeDict({n: nlp.get_pipe(n) for n in pipe_names}, loss_scales) + trained_pipes_local = {n: nlp.get_pipe(n) for n in pipe_names} + trained_pipes = PipeDict(trained_pipes_local, loss_scales) trained_pipes_params = set(trained_pipes.parameters()) phase_training_data = [ td @@ -572,7 +653,9 @@ def train( if hasattr(accel_optim.optimizer, "initialize"): accel_optim.optimizer.initialize() - cumulated_data = defaultdict(lambda: 0.0, count=0) + ewm_state = grad_mean = grad_var = None + default_metrics = dict(count=0, spikes=0) + cumulated_data = defaultdict(lambda: 0, **default_metrics) all_metrics = [] set_seed(seed) with ( @@ -590,34 +673,33 @@ def train( disable=not is_main_process, smoothing=0.3, ): + if ( + save_model + and is_main_process + and (step % checkpoint_interval) == 0 + ): + # torch.save(nlp, output_model_dir / "model.pt") + nlp.to_disk(output_model_dir) if ( is_main_process and step > 0 and (step % validation_interval) == 0 ): scores = scorer(nlp, val_docs) if val_docs else {} - all_metrics.append( - { - "step": step, - "lr": accel_optim.param_groups[0]["lr"], - **cumulated_data, - **scores, - } - ) - cumulated_data.clear() + metrics = { + "step": step, + "lr": accel_optim.param_groups[0]["lr"], + **cumulated_data, + **scores, + } + all_metrics.append(metrics) + cumulated_data = defaultdict(lambda: 0, **default_metrics) train_metrics_path.write_text(json.dumps(all_metrics, indent=2)) if logger: - logger.log_metrics(flatten_dict(all_metrics[-1])) + logger.log_metrics(flatten_dict(metrics)) if on_validation_callback: - on_validation_callback(all_metrics[-1]) - - if ( - save_model - and is_main_process - and (step % checkpoint_interval) == 0 - ): - nlp.to_disk(output_model_dir) + on_validation_callback(metrics) if step == max_steps: break @@ -626,7 +708,7 @@ def train( batches = list(next(iterator)) batches_pipe_names = list( - flatten( + flatten_once( [ [td.pipe_names or pipe_names] * len(b) for td, b in zip(phase_training_data, batches) @@ -636,17 +718,15 @@ def train( batches = list(flatten(batches)) # Synchronize stats 
between sub-batches across workers - batch_stats = {} + local_batch_stats = {} for b in batches: - fill_flat_stats(b, result=batch_stats) - batch_stats = { - k: sum(v) - for k, v in ld_to_dl(gather_object([batch_stats])).items() - } + fill_flat_stats(b, result=local_batch_stats) + batch_stats = gather_object([local_batch_stats]) + batch_stats = {k: sum(v) for k, v in ld_to_dl(batch_stats).items()} for b in batches: set_flat_stats(b, batch_stats) - res_stats = defaultdict(lambda: 0.0) + local_res_stats = defaultdict(lambda: 0.0) for idx, (batch, batch_pipe_names) in enumerate( zip(batches, batches_pipe_names) ): @@ -658,29 +738,43 @@ def train( if idx < len(batches) - 1 else nullcontext() ) - with cache_ctx, no_sync_ctx: - all_res, loss = trained_pipes( - batch, - enable=batch_pipe_names, + try: + with cache_ctx, no_sync_ctx: + all_res, loss = trained_pipes( + batch, + enable=batch_pipe_names, + ) + for name, res in all_res.items(): + for k, v in res.items(): + if ( + isinstance(v, (float, int)) + or isinstance(v, torch.Tensor) + and v.ndim == 0 + ): + local_res_stats[k] += float(v) + del k, v + del res + del all_res + if ( + isinstance(loss, torch.Tensor) + and loss.requires_grad + ): + accelerator.backward(loss) + except torch.cuda.OutOfMemoryError: + print( + "Out of memory error encountered when processing a " + "batch with the following statistics:" ) - for name, res in all_res.items(): - for k, v in res.items(): - if ( - isinstance(v, (float, int)) - or isinstance(v, torch.Tensor) - and v.ndim == 0 - ): - res_stats[k] += float(v) - del k, v - del res - del all_res - accelerator.backward(loss) + print(local_batch_stats) + raise del loss # Sync output stats after forward such as losses, supports, etc. res_stats = { k: sum(v) - for k, v in ld_to_dl(gather_object([dict(res_stats)])).items() + for k, v in ld_to_dl( + gather_object([dict(local_res_stats)]) + ).items() } if is_main_process: for k, v in batch_stats.items(): @@ -689,8 +783,53 @@ def train( cumulated_data[k] += v del batch_stats, res_stats - accelerator.clip_grad_norm_(grad_params, max_grad_norm) - accel_optim.step() + accelerator.unscale_gradients() + + # Log gradients + if log_weight_grads: + for pipe_name, pipe in trained_pipes_local.items(): + for param_name, param in pipe.named_parameters(): + if param.grad is not None: + cumulated_data[ + f"grad_norm/{pipe_name}/{param_name}" + ] += param.grad.norm().item() + cumulated_data[ + f"param_norm/{pipe_name}/{param_name}" + ] += param.norm().item() + + grad_norm = torch.nn.utils.clip_grad_norm_( + grad_params, grad_max_norm, norm_type=2 + ).item() + + # Detect grad spikes and skip the step if necessary + if grad_dev_policy is not None: + if step > grad_ewm_window and ( + grad_norm - grad_mean + ) > grad_max_dev * math.sqrt(grad_var): + spike = True + cumulated_data["spikes"] += 1 + else: + grad_mean, grad_var, ewm_state = ewm_moments( + grad_norm, grad_ewm_window, state=ewm_state + ) + spike = False + + if spike and grad_dev_policy == "clip_mean": + torch.nn.utils.clip_grad_norm_( + grad_params, grad_mean, norm_type=2 + ) + elif spike and grad_dev_policy == "clip_threshold": + torch.nn.utils.clip_grad_norm_( + grad_params, + grad_mean + math.sqrt(grad_var) * grad_max_dev, + norm_type=2, + ) + + if grad_dev_policy != "skip" or not spike: + accel_optim.step() + + cumulated_data["count"] += 1 + cumulated_data["grad_norm/__all__"] += grad_norm del iterator diff --git a/tests/training/dep_parser_config.yml b/tests/training/dep_parser_config.yml index 19bd9b036..607ba9754 100644 --- 
a/tests/training/dep_parser_config.yml +++ b/tests/training/dep_parser_config.yml @@ -57,3 +57,5 @@ train: scorer: ${ scorer } num_workers: 0 optimizer: ${ optimizer } + grad_dev_policy: "clip_mean" + log_weight_grads: true diff --git a/tests/training/ner_qlf_diff_bert_config.yml b/tests/training/ner_qlf_diff_bert_config.yml index 2a268fcb0..6b8ce327a 100644 --- a/tests/training/ner_qlf_diff_bert_config.yml +++ b/tests/training/ner_qlf_diff_bert_config.yml @@ -124,3 +124,4 @@ train: scorer: ${ scorer } num_workers: 0 optimizer: ${ optimizer } + grad_dev_policy: "skip" diff --git a/tests/training/ner_qlf_same_bert_config.yml b/tests/training/ner_qlf_same_bert_config.yml index 3a6ebe72f..a429b1d5a 100644 --- a/tests/training/ner_qlf_same_bert_config.yml +++ b/tests/training/ner_qlf_same_bert_config.yml @@ -115,3 +115,4 @@ train: scorer: ${ scorer } num_workers: 0 optimizer: ${ optimizer } + grad_dev_policy: "clip_threshold" From e687df70457abecc10ff62a1a37847a889eb3400 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Wed, 19 Feb 2025 18:28:54 +0100 Subject: [PATCH 2/6] fix: in tuning, seed hyperparameter sampler --- edsnlp/tune.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/edsnlp/tune.py b/edsnlp/tune.py index d926d1b49..015bf1d91 100644 --- a/edsnlp/tune.py +++ b/edsnlp/tune.py @@ -14,6 +14,7 @@ from confit.utils.random import set_seed from optuna.importance import FanovaImportanceEvaluator, get_param_importances from optuna.pruners import MedianPruner +from optuna.samplers import TPESampler from pydantic import BaseModel, confloat, conint from edsnlp.training.trainer import GenericScorer, registry, train @@ -287,6 +288,7 @@ def objective(trial): study = optuna.create_study( direction="maximize", pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=2), + sampler=TPESampler(seed=random.randint(0, 2**32 - 1)), ) study.optimize(objective, n_trials=n_trials) return study From 562bbef63c73ec93c86485044576e9f3db040c64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Sun, 16 Feb 2025 23:06:43 +0100 Subject: [PATCH 3/6] feat: added various training loggers using confit.Draft --- .pre-commit-config.yaml | 2 +- changelog.md | 1 + docs/training/index.md | 0 docs/training/loggers.md | 154 ++++++ docs/tutorials/training.md | 45 +- edsnlp/core/pipeline.py | 20 +- edsnlp/core/registries.py | 111 ++-- edsnlp/pipes/base.py | 6 +- edsnlp/training/loggers.py | 537 ++++++++++++++++++++ edsnlp/training/trainer.py | 407 ++++++++------- edsnlp/utils/typing.py | 15 +- mkdocs.yml | 2 + pyproject.toml | 16 +- tests/test_pipeline.py | 6 +- tests/training/ner_qlf_diff_bert_config.yml | 23 +- tests/training/ner_qlf_same_bert_config.yml | 1 + tests/training/test_train.py | 51 +- 17 files changed, 1102 insertions(+), 295 deletions(-) create mode 100644 docs/training/index.md create mode 100644 docs/training/loggers.md create mode 100644 edsnlp/training/loggers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bfaa2cd70..efdf4e9ec 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: # ruff - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. 
-    rev: 'v0.6.4'
+    rev: 'v0.9.6'
     hooks:
       - id: ruff
         args: ['--config', 'pyproject.toml', '--fix', '--show-fixes']
diff --git a/changelog.md b/changelog.md
index f853a1c6a..48f2c35f9 100644
--- a/changelog.md
+++ b/changelog.md
@@ -11,6 +11,7 @@
 - Provided a [detailed tutorial](./docs/tutorials/tuning.md) on hyperparameter tuning, covering usage scenarios and configuration options.
 - `ScheduledOptimizer` (e.g., `@core: "optimizer"`) now supports importing optimizers using their qualified name (e.g., `optim: "torch.optim.Adam"`).
 - Added gradient spike detection to the `edsnlp.train` script, as well as per-weight-layer gradient logging.
+- Added support for multiple loggers (`tensorboard`, `wandb`, `comet_ml`, `aim`, `mlflow`, `clearml`, `dvclive`, `csv`, `json`, `rich`) in `edsnlp.train` via the `logger` parameter. Defaults to `["json", "rich"]` for backward compatibility.

 ### Changed

diff --git a/docs/training/index.md b/docs/training/index.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/docs/training/loggers.md b/docs/training/loggers.md
new file mode 100644
index 000000000..527a7d02a
--- /dev/null
+++ b/docs/training/loggers.md
@@ -0,0 +1,154 @@
+# Loggers
+
+When training a model, it is important to keep track of the training process, model performance at different stages, and statistics about the training data over time. This is where loggers come in: they store this information so that it can be analyzed and visualized later.
+
+The EDS-NLP training API (`edsnlp.train`) relies on `accelerate`'s integrations of popular loggers, as well as a few custom loggers.
+You can configure loggers in `edsnlp.train` via the `logger` parameter of the `train` function by specifying:
+
+- a string, a logger instance, or a partially initialized logger instance, e.g.
+
+    === "Via the Python API"
+        ```{ .python .no-check }
+        from edsnlp.training.loggers import CSVLogger
+        from edsnlp.training import train
+
+        logger = CSVLogger.draft()
+        train(..., logger=logger)
+        # or train(..., logger="csv")
+        ```
+
+    === "Via a config file"
+        ```yaml
+        train:
+          ...
+          logger:
+            "@loggers": csv
+          ...
+        ```
+
+- or a list of strings / logger instances, e.g.
+
+    === "Via the Python API"
+        ```{ .python .no-check }
+        from edsnlp.training.loggers import CSVLogger
+        from edsnlp.training import train
+
+        loggers = ["tensorboard", CSVLogger.draft(...)]
+        train(..., logger=loggers)
+        ```
+
+    === "Via a config file"
+        ```yaml
+        train:
+          ...
+          logger:
+            - tensorboard  # as a string
+            - "@loggers": csv  # as a (partially) instantiated logger
+          ...
+        ```
+
+!!! note "Draft objects"
+
+    `edsnlp.train` will provide a default project name and logging dir for loggers that require these parameters, but it is
+    recommended to set the project name explicitly in the logger configuration. For these loggers, if you don't want to set
+    the project name yourself, you can either:
+
+    - call `CSVLogger.draft(...)` with the normal init parameters minus the `project_name` or `logging_dir` parameters,
+      which will cause a `Draft[CSVLogger]` object to be returned if some required parameters are missing
+    - or use `"@loggers": csv` in the config file, which will also cause a `Draft[CSVLogger]` object to be returned if some required
+      parameters are missing
+
+    If you do not want a `Draft` object to be returned, call `CSVLogger` directly.
+
+The supported loggers are listed below.
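+For instance, the two styles above can be combined in a single list. Here is a minimal sketch (the remaining `train` arguments are elided as in the examples above, and the `run1.csv` file name is purely illustrative):
+
+```{ .python .no-check }
+from edsnlp.training import train
+from edsnlp.training.loggers import CSVLogger, RichLogger
+
+logger = [
+    RichLogger(),  # live table in the terminal, nothing written to disk
+    # Draft[CSVLogger]: `logging_dir` is omitted and will be filled in by `train`
+    CSVLogger.draft(file_name="run1.csv"),
+    "tensorboard",  # resolved through the `loggers` registry
+]
+train(..., logger=logger)
+```
+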
+ +### RichLogger {: #edsnlp.training.loggers.RichLogger } + +::: edsnlp.training.loggers.RichLogger.__init__ + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### CSVLogger {: #edsnlp.training.loggers.CSVLogger } + +::: edsnlp.training.loggers.CSVLogger.__init__ + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### JSONLogger {: #edsnlp.training.loggers.JSONLogger } + +::: edsnlp.training.loggers.JSONLogger.__init__ + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### TensorBoardLogger {: #edsnlp.training.loggers.TensorBoardLogger } + +::: edsnlp.training.loggers.TensorBoardLogger + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### AimLogger {: #edsnlp.training.loggers.AimLogger } + +::: edsnlp.training.loggers.AimLogger + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### WandBLogger {: #edsnlp.training.loggers.WandBLogger } + +::: edsnlp.training.loggers.WandBLogger + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### MLflowLogger {: #edsnlp.training.loggers.MLflowLogger } + +::: edsnlp.training.loggers.MLflowLogger + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### CometMLLogger {: #edsnlp.training.loggers.CometMLLogger } + +::: edsnlp.training.loggers.CometMLLogger + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true + +### DVCLiveLogger {: #edsnlp.training.loggers.DVCLiveLogger } + +::: edsnlp.training.loggers.DVCLiveLogger + options: + sections: ["text", "parameters"] + heading_level: 4 + show_bases: false + show_source: false + only_class_level: true diff --git a/docs/tutorials/training.md b/docs/tutorials/training.md index e7d887364..796bb3f19 100644 --- a/docs/tutorials/training.md +++ b/docs/tutorials/training.md @@ -1,4 +1,4 @@ -# Training API +# Training API {: #edsnlp.training.trainer.train } In this tutorial, we'll see how we can quickly train a deep learning model with EDS-NLP using the `edsnlp.train` function. 
@@ -170,6 +170,30 @@ EDS-NLP supports training models either [from the command line](#from-the-comman
           - '@factory': eds.standoff_dict2doc
             span_setter: 'gold_spans'

+    logger:
+      - '@loggers': csv
+      - '@loggers': rich
+        fields:
+          step: {}
+          (.*)loss:
+            goal: lower_is_better
+            format: "{:.2e}"
+            goal_wait: 2
+          lr:
+            format: "{:.2e}"
+          speed/(.*):
+            format: "{:.2f}"
+            name: \1
+          "(.*?)/micro/(f|r|p)$":
+            goal: higher_is_better
+            format: "{:.2%}"
+            goal_wait: 1
+            name: \1_\2
+          grad_norm/__all__:
+            format: "{:.2e}"
+            name: grad_norm
+      # - wandb  # enable if you can and want to track with wandb
+
     # 🚀 TRAIN SCRIPT OPTIONS
     # -> python -m edsnlp.train --config configs/config.yml
     train:
@@ -182,6 +206,7 @@
         grad_max_norm: 1.0
         scorer: ${ scorer }
         optimizer: ${ optimizer }
+        logger: ${ logger }
         # Do preprocessing in parallel on 1 worker
         num_workers: 1
         # Enable on Mac OS X or if you don't want to use available GPUs
@@ -214,6 +239,7 @@
     import edsnlp
     from edsnlp.training import train, ScheduledOptimizer, TrainingData
     from edsnlp.metrics.ner import NerExactMetric
+    from edsnlp.training.loggers import CSVLogger, RichLogger, WandBLogger
     import edsnlp.pipes as eds
     import torch

@@ -270,6 +296,22 @@
         },
     )

+    #
+    logger = [
+        CSVLogger(),
+        RichLogger(
+            fields={
+                "step": {},
+                "(.*)loss": {"goal": "lower_is_better", "format": "{:.2e}", "goal_wait": 2},
+                "lr": {"format": "{:.2e}"},
+                "speed/(.*)": {"format": "{:.2f}", "name": "\\1"},
+                "(.*?)/micro/(f|r|p)$": {"goal": "higher_is_better", "format": "{:.2%}", "goal_wait": 1, "name": "\\1_\\2"},
+                "grad_norm/__all__": {"format": "{:.2e}", "name": "grad_norm"},
+            }
+        ),
+        # WandBLogger(),  # if you can and want to track with Weights & Biases
+    ]
+
     # 🚀 TRAIN
     train(
         nlp=nlp,
@@ -286,6 +328,7 @@
         optimizer=optimizer,
         grad_max_norm=1.0,
         output_dir="artifacts",
+        logger=logger,
         # Do preprocessing in parallel on 1 worker
         num_workers=1,
         # Enable on Mac OS X or if you don't want to use available GPUs
diff --git a/edsnlp/core/pipeline.py b/edsnlp/core/pipeline.py
index d92f99925..b4288d611 100644
--- a/edsnlp/core/pipeline.py
+++ b/edsnlp/core/pipeline.py
@@ -44,7 +44,7 @@
 from spacy.vocab import Vocab, create_vocab
 from typing_extensions import Literal, Self

-from ..core.registries import PIPE_META, CurriedFactory, FactoryMeta, registry
+from ..core.registries import PIPE_META, DraftPipe, FactoryMeta, registry
 from ..utils.collections import (
     FrozenDict,
     FrozenList,
@@ -238,9 +238,9 @@ def create_pipe(
                 **(config if config is not None else {}),
             }
         ).resolve(registry=registry)
-        if isinstance(pipe, CurriedFactory):
+        if isinstance(pipe, DraftPipe):
             if name is None:
-                name = signature(pipe.factory).parameters.get("name").default
+                name = signature(pipe._func).parameters.get("name").default
             if name is None or name == Parameter.empty:
                 name = factory
             pipe = pipe.instantiate(nlp=self, path=(name,))
@@ -297,8 +297,8 @@ def add_pipe(
             raise ValueError(
                 "Can't pass config or name with an instantiated component",
             )
-        if isinstance(factory, CurriedFactory):
-            name = name or factory.kwargs.get("name")
+        if isinstance(factory, DraftPipe):
+            name = name or factory._kwargs.get("name")
             factory = factory.instantiate(nlp=self, path=(name,))
             pipe = factory
@@ -585,13 +585,13 @@ def
from_config(
     def _add_pipes(
         self,
         pipeline: Sequence[str],
-        components: Dict[str, CurriedFactory],
+        components: Dict[str, DraftPipe],
         exclude: Container[str],
         enable: Container[str],
         disable: Container[str],
     ):
         try:
-            components = CurriedFactory.instantiate(components, nlp=self)
+            components = DraftPipe.instantiate(components, nlp=self)
         except ConfitValidationError as e:
             e = ConfitValidationError(
                 e.raw_errors,
@@ -1277,9 +1277,9 @@
     def load_from_huggingface(
     owner, model_name = repo_id.split("/")
     module_name = model_name.replace("-", "_")
-    assert (
-        len(repo_id.split("/")) == 2
-    ), "Invalid repo_id format (expected 'owner/repo_name' format)"
+    assert len(repo_id.split("/")) == 2, (
+        "Invalid repo_id format (expected 'owner/repo_name' format)"
+    )
     path = None
     mtime = None
     try:
diff --git a/edsnlp/core/registries.py b/edsnlp/core/registries.py
index 8628b5f4a..e6c22ad0d 100644
@@ -2,13 +2,24 @@
 import types
 from dataclasses import dataclass
 from functools import wraps
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+)
 from weakref import WeakKeyDictionary

 import catalogue
 import spacy
 from confit import Config, Registry, RegistryCollection, set_default_registry
 from confit.errors import ConfitValidationError, patch_errors
+from confit.registry import Draft
 from spacy.pipe_analysis import validate_attrs

 import edsnlp
@@ -57,14 +68,16 @@ class FactoryMeta:
     default_config: Dict


-class CurriedFactory:
+T = TypeVar("T")
+
+
+class DraftPipe(Draft[T]):
     def __init__(self, func, kwargs):
-        self.kwargs = kwargs
-        self.factory = func
+        super().__init__(func, kwargs)
         self.instantiated = None
         self.error = None

-    def maybe_nlp(self) -> Union["CurriedFactory", Any]:
+    def maybe_nlp(self) -> Union["DraftPipe", Any]:
         """
         If the factory requires an nlp argument and the user has explicitly
         provided it (this is unusual, we usually expect the factory to be
@@ -73,11 +86,11 @@

         Returns
         -------
-        Union["CurriedFactory", Any]
+        Union["DraftPipe", Any]
         """
         from edsnlp.core.pipeline import Pipeline, PipelineProtocol

         sig = inspect.signature(self._func)
         if (
             not (
                 "nlp" in sig.parameters
                 and (
                     sig.parameters["nlp"].default is sig.empty
                     or sig.parameters["nlp"].annotation in (Pipeline, PipelineProtocol)
                 )
             )
-            or "nlp" in self.kwargs
-        ) and not self.search_curried_factory(self.kwargs):
-            return self.factory(**self.kwargs)
+            or "nlp" in self._kwargs
+        ) and not self.search_nested_drafts(self._kwargs):
+            return self._func(**self._kwargs)
         return self

     @classmethod
-    def search_curried_factory(cls, obj):
-        if isinstance(obj, CurriedFactory):
+    def search_nested_drafts(cls, obj):
+        if isinstance(obj, DraftPipe):
             return obj
         elif isinstance(obj, dict):
             for value in obj.values():
-                result = cls.search_curried_factory(value)
+                result = cls.search_nested_drafts(value)
                 if result is not None:
                     return result
         elif isinstance(obj, (tuple, list, set)):
             for value in obj:
-                result = cls.search_curried_factory(value)
+                result = cls.search_nested_drafts(value)
                 if result is not None:
                     return result
         return None

     def instantiate(
-        obj: Any,
+        self,
         nlp: "edsnlp.Pipeline",
         path: Optional[Sequence[str]] = (),
     ):
         """
         We need to support passing in the nlp object and name to factories.
Since they can be nested, we need to add them to every factory in the config. """ - if isinstance(obj, CurriedFactory): - if obj.error is not None: - raise obj.error + if isinstance(self, DraftPipe): + if self.error is not None: + raise self.error - if obj.instantiated is not None: - return obj.instantiated + if self.instantiated is not None: + return self.instantiated name = path[0] if len(path) == 1 else None parameters = ( - inspect.signature(obj.factory.__init__).parameters - if isinstance(obj.factory, type) - else inspect.signature(obj.factory).parameters + inspect.signature(self._func.__init__).parameters + if isinstance(self._func, type) + else inspect.signature(self._func).parameters ) kwargs = { - key: CurriedFactory.instantiate( - obj=value, + key: DraftPipe.instantiate( + self=value, nlp=nlp, path=(*path, key), ) - for key, value in obj.kwargs.items() + for key, value in self._kwargs.items() } try: if nlp and "nlp" in parameters: kwargs["nlp"] = nlp if name and "name" in parameters: kwargs["name"] = name - obj.instantiated = obj.factory(**kwargs) + self.instantiated = self._func(**kwargs) except ConfitValidationError as e: - obj.error = e + self.error = e raise ConfitValidationError( patch_errors(e.raw_errors, path, model=e.model), model=e.model, - name=obj.factory.__module__ + "." + obj.factory.__qualname__, + name=self._func.__module__ + "." + self._func.__qualname__, ) # .with_traceback(None) # except Exception as e: # obj.error = e # raise ConfitValidationError([ErrorWrapper(e, path)]) - return obj.instantiated - elif isinstance(obj, dict): + return self.instantiated + elif isinstance(self, dict): instantiated = {} errors = [] - for key, value in obj.items(): + for key, value in self.items(): try: - instantiated[key] = CurriedFactory.instantiate( - obj=value, + instantiated[key] = DraftPipe.instantiate( + self=value, nlp=nlp, path=(*path, key), ) @@ -170,42 +183,31 @@ def instantiate( if errors: raise ConfitValidationError(errors) return instantiated - elif isinstance(obj, (tuple, list)): + elif isinstance(self, (tuple, list)): instantiated = [] errors = [] - for i, value in enumerate(obj): + for i, value in enumerate(self): try: instantiated.append( - CurriedFactory.instantiate(value, nlp, (*path, str(i))) + DraftPipe.instantiate(value, nlp, (*path, str(i))) ) except ConfitValidationError as e: # pragma: no cover errors.append(e.raw_errors) if errors: raise ConfitValidationError(errors) - return type(obj)(instantiated) + return type(self)(instantiated) else: - return obj + return self - def _raise_curried_factory_error(self): + def _raise_draft_error(self): raise TypeError( - f"This component CurriedFactory({self.factory}) has not been instantiated " + f"This {self} component has not been instantiated " f"yet, likely because it was missing an `nlp` pipeline argument. 
You "
            f"should either:\n"
            f"- add it to a pipeline: `pipe = nlp.add_pipe(pipe)`\n"
            f"- or fill its `nlp` argument: `pipe = factory(nlp=nlp, ...)`"
        )

-    def __call__(self, *args, **kwargs):
-        self._raise_curried_factory_error()
-
-    def __getattr__(self, name):
-        if name.startswith("__"):
-            raise AttributeError(name)
-        self._raise_curried_factory_error()
-
-    def __repr__(self):
-        return f"CurriedFactory({self.factory})"
-

 glob = []

@@ -274,7 +276,7 @@
     def check_and_return():
         if catalogue.check_exists(*registry_path):
             func = catalogue._get(registry_path)

-            return lambda **kwargs: CurriedFactory(func, kwargs=kwargs).maybe_nlp()
+            return lambda **kwargs: DraftPipe(func, kwargs=kwargs).maybe_nlp()

     # Steps 1 & 2
     func = check_and_return()
@@ -427,7 +429,7 @@
         def invoke(validated_fn, kwargs):

             @wraps(fn)
             def curried_registered_fn(**kwargs):
-                return CurriedFactory(registered_fn, kwargs).maybe_nlp()
+                return DraftPipe(registered_fn, kwargs).maybe_nlp()

             return (
                 curried_registered_fn
@@ -453,6 +455,7 @@ class registry(RegistryCollection):
     core = Registry(("edsnlp", "core"), entry_points=True)
     optimizers = Registry(("edsnlp", "optimizers"), entry_points=True)
     schedules = Registry(("edsnlp", "schedules"), entry_points=True)
+    loggers = Registry(("edsnlp", "loggers"), entry_points=True)


 set_default_registry(registry)
diff --git a/edsnlp/pipes/base.py b/edsnlp/pipes/base.py
index e92fc50b8..b17f9ba66 100644
--- a/edsnlp/pipes/base.py
+++ b/edsnlp/pipes/base.py
@@ -12,7 +12,7 @@
 from spacy.tokens import Doc, Span

 from edsnlp.core import PipelineProtocol
-from edsnlp.core.registries import CurriedFactory
+from edsnlp.core.registries import DraftPipe
 from edsnlp.utils.span_getters import (
     SpanGetter,  # noqa: F401
     SpanGetterArg,  # noqa: F401
@@ -42,7 +42,7 @@
     def __init__(cls, name, bases, dct):

     def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
         # If this component is missing the nlp argument, we curry it with the
-        # provided arguments and return a CurriedFactory object.
+        # provided arguments and return a DraftPipe object.
sig = inspect.signature(cls.__init__)
         try:
             bound = sig.bind_partial(None, nlp, *args, **kwargs)
@@ -52,7 +52,7 @@ def __call__(cls, nlp=inspect.Signature.empty, *args, **kwargs):
                 and sig.parameters["nlp"].default is sig.empty
                 and bound.arguments.get("nlp", sig.empty) is sig.empty
             ):
-                return CurriedFactory(cls, bound.arguments)
+                return DraftPipe(cls, bound.arguments)
             if nlp is inspect.Signature.empty:
                 bound.arguments.pop("nlp", None)
         except TypeError:  # pragma: no cover
diff --git a/edsnlp/training/loggers.py b/edsnlp/training/loggers.py
new file mode 100644
index 000000000..94ba11f72
--- /dev/null
+++ b/edsnlp/training/loggers.py
@@ -0,0 +1,537 @@
+import csv
+import json
+import os
+import warnings
+from typing import Any, Dict, Optional, Union
+
+import accelerate.tracking
+from rich_logger import RichTablePrinter
+
+import edsnlp
+
+
+def flatten_dict(d, path=""):
+    if not isinstance(d, (list, dict)):
+        return {path: d}
+
+    if isinstance(d, list):
+        items = enumerate(d)
+    else:
+        items = d.items()
+
+    return {
+        k: v
+        for key, val in items
+        for k, v in flatten_dict(val, f"{path}/{key}" if path else key).items()
+    }
+
+
+@edsnlp.registry.loggers.register("csv", auto_draft_in_config=True)
+class CSVLogger(accelerate.tracking.GeneralTracker):
+    name = "csv"
+    requires_logging_directory = True
+
+    @accelerate.tracking.on_main_process
+    def __init__(
+        self,
+        *,
+        logging_dir: Union[str, os.PathLike],
+        file_name: str = "metrics.csv",
+        **kwargs,
+    ):
+        """
+        A simple CSV-based logger that writes logs to a CSV file. By default,
+        with `edsnlp.train` the CSV file is located under a local directory
+        `${CWD}/artifacts/metrics.csv`.
+
+        !!! warning "Consistent Keys"
+
+            This logger expects that the `values` dictionary passed to `log` has
+            consistent keys across all calls. If a new key is encountered in a
+            subsequent call, it will be ignored and a warning will be issued.
+
+        Parameters
+        ----------
+        logging_dir : str or os.PathLike
+            Directory in which to store the CSV.
+        file_name : str, optional
+            Name of the CSV file. Defaults to "metrics.csv".
+        """
+        super().__init__()
+        self.logging_dir = logging_dir
+        os.makedirs(self.logging_dir, exist_ok=True)
+
+        self.file_path = os.path.join(self.logging_dir, file_name)
+
+        self._file = open(self.file_path, mode="a", newline="")
+        self._writer = csv.writer(self._file)
+        self._columns = None
+        self._has_header = False
+
+    @property
+    def tracker(self):  # pragma: no cover
+        return None
+
+    @accelerate.tracking.on_main_process
+    def store_init_configuration(self, values: Dict[str, Any]):
+        pass
+
+    @accelerate.tracking.on_main_process
+    def log(self, values: Dict[str, Any], step: Optional[int] = None):
+        """
+        Logs `values` to the CSV file, at an optional `step`.
+
+        - If it's the first call, the columns are inferred from the keys in `values`
+          plus a "step" column if the user provides `step`.
+        - All subsequent calls must use the same columns. Any missing columns get
+          written as empty, any new columns generate a warning.
+        """
+        values = flatten_dict(values)
+
+        if self._columns is None:
+            self._columns = list({**{"step": None}, **values}.keys())
+            self._writer.writerow(self._columns)
+            self._has_header = True
+
+        # Build a row in the order of self._columns
+        row = []
+        for col in self._columns:
+            if col == "step":
+                row.append(step if step is not None else "")
+            else:
+                # missing columns are written as empty values
+                row.append(values.get(col, ""))
+
+        for extra_key in values.keys():
+            if extra_key not in self._columns:
+                warnings.warn(
+                    f"CSVLogger: encountered a new field '{extra_key}' that was not in "
+                    f"the field keys of the first logged step. It will not be logged."
+                )
+
+        self._writer.writerow(row)
+        self._file.flush()
+
+    @accelerate.tracking.on_main_process
+    def finish(self):
+        self._file.close()
+
+
+@edsnlp.registry.loggers.register("json", auto_draft_in_config=True)
+class JSONLogger(accelerate.tracking.GeneralTracker):
+    name = "json"
+    requires_logging_directory = True
+
+    @accelerate.tracking.on_main_process
+    def __init__(
+        self,
+        logging_dir: Union[str, os.PathLike],
+        file_name: str = "metrics.json",
+        **kwargs,
+    ):
+        """
+        A simple JSON-based logger that writes logs to a JSON file as a
+        list of dictionaries. By default, with `edsnlp.train` the JSON file
+        is located under a local directory `${CWD}/artifacts/metrics.json`.
+
+        This logger is not recommended for large and frequent logging, as it
+        re-writes the entire JSON file on every call. Prefer
+        [`CSVLogger`][edsnlp.training.loggers.CSVLogger] for frequent
+        and heavy logging.
+
+        Parameters
+        ----------
+        logging_dir : str or os.PathLike
+            Directory in which to store the JSON file.
+        file_name : str, optional
+            Name of the JSON file. Defaults to "metrics.json".
+        """
+        super().__init__()
+        self.logging_dir = logging_dir
+        os.makedirs(self.logging_dir, exist_ok=True)
+
+        self._file_path = os.path.join(self.logging_dir, file_name)
+        self._logs = []
+
+    @property
+    def tracker(self):  # pragma: no cover
+        return None
+
+    @accelerate.tracking.on_main_process
+    def store_init_configuration(self, values: Dict[str, Any]):
+        pass
+
+    @accelerate.tracking.on_main_process
+    def log(self, values: Dict[str, Any], step: Optional[int] = None):
+        """
+        Logs `values` along with a `step` (if provided).
+
+        On every call, we:
+        1. Append a new record to our in-memory list.
+        2. Write out the entire JSON file containing all records.
+        """
+        log_entry = {"step": step, **values}
+        self._logs.append(log_entry)
+
+        with open(self._file_path, mode="w") as f:
+            json.dump(self._logs, f, indent=2)
+
+    @accelerate.tracking.on_main_process
+    def finish(self):
+        pass
+
+
+@edsnlp.registry.loggers.register("rich")
+class RichLogger(accelerate.tracking.GeneralTracker):
+    DEFAULT_FIELDS = {
+        "step": {},
+        "(.*)loss": {
+            "goal": "lower_is_better",
+            "format": "{:.2e}",
+            "goal_wait": 2,
+        },
+        "lr": {"format": "{:.2e}"},
+        "speed/(.*)": {"format": "{:.2f}", "name": r"\1"},
+        "(.*?)/micro/(f|r|p)$": {
+            "goal": "higher_is_better",
+            "format": "{:.2%}",
+            "goal_wait": 1,
+            "name": r"\1_\2",
+        },
+        "(.*?)/(uas|las)": {
+            "goal": "higher_is_better",
+            "format": "{:.2%}",
+            "goal_wait": 1,
+            "name": r"\1_\2",
+        },
+        "grad_norm/__all__": {
+            "format": "{:.2e}",
+            "name": "grad_norm",
+        },
+    }
+
+    name = "rich"
+    requires_logging_directory = False
+
+    @accelerate.tracking.on_main_process
+    def __init__(
+        self,
+        run_name: Optional[str] = None,
+        fields: Dict[str, Union[Dict, bool]] = None,
+        key: Optional[str] = None,
+        hijack_tqdm: bool = True,
+        **kwargs,
+    ):
+        """
+        A logger that displays logs in a Rich-based table using
+        [rich-logger](https://github.com/percevalw/rich-logger).
+        This logger is also available via the loggers registry as `rich`.
+
+        !!! warning "No Disk Logging"
+
+            This logger doesn't save logs to disk. It's meant for displaying
+            logs in a pretty table during training. If you need to save logs
+            to disk, consider combining this logger with any other logger.
+
+        Parameters
+        ----------
+        fields: Dict[str, Union[Dict, bool]]
+            Field descriptors. Each key is a regex matched against the names of
+            the fields to log, and each value is either False to hide the
+            matching columns, or a dict describing how to display them:
+
+            - name: the display name of the column (may reference regex groups,
+              e.g. r"\1_\2")
+            - goal: "lower_is_better" or "higher_is_better"
+            - format: a format string used to render the value
+            - goal_wait: the number of values to wait before applying the goal
+              styling
+
+            This defaults to a set of metrics and stats that are commonly
+            logged during EDS-NLP training.
+        key: Optional[str]
+            Name of the field used to group the logs into rows.
+        hijack_tqdm: bool
+            Whether to replace the tqdm progress bar with a rich progress bar,
+            which integrates better with the rich table.
+        """
+        super().__init__()
+
+        self.run_name = run_name
+        fields = fields if fields is not None else self.DEFAULT_FIELDS
+        self.fields = fields or {}
+
+        self.printer = RichTablePrinter(key=key, fields=self.fields)
+
+        if hijack_tqdm:
+            self.printer.hijack_tqdm()
+
+    @property
+    def tracker(self):
+        return self.printer
+
+    @accelerate.tracking.on_main_process
+    def store_init_configuration(self, values: Dict[str, Any]):
+        pass
+
+    @accelerate.tracking.on_main_process
+    def log(self, values: Dict[str, Any], step: Optional[int] = None):
+        """
+        Logs values in the Rich table. If `step` is provided, we include it in the
+        logged data.
+        """
+        combined = {"step": step, **flatten_dict(values)}
+        self.printer.log_metrics(combined)
+
+    @accelerate.tracking.on_main_process
+    def finish(self):
+        """
+        Finalize the table (e.g., stop rendering).
+        """
+        self.printer.finalize()
+
+
+@edsnlp.registry.loggers.register("tensorboard", auto_draft_in_config=True)
+class TensorBoardLogger(accelerate.tracking.TensorBoardTracker):
+    def __init__(
+        self,
+        project_name: str,
+        logging_dir: Optional[Union[str, os.PathLike]] = None,
+    ):
+        """
+        Logger for [TensorBoard](https://github.com/tensorflow/tensorboard).
+        This logger is also available via the loggers registry as `tensorboard`.
+
+        Parameters
+        ----------
+        project_name: str
+            Name of the project.
+        logging_dir: Union[str, os.PathLike]
+            Directory in which to store the TensorBoard logs. Logs of different runs
+            will be stored in `logging_dir/project_name`.
+            The environment variable `TENSORBOARD_LOGGING_DIR` takes precedence over
+            this argument.
+        """
+        env_logging_dir = os.environ.get("TENSORBOARD_LOGGING_DIR", None)
+        if env_logging_dir is not None and logging_dir is not None:  # pragma: no cover
+            warnings.warn(
+                f"Using the env TENSORBOARD_LOGGING_DIR={env_logging_dir} as the "
+                f"logging directory for TensorBoard, instead of {logging_dir}."
+            )
+            logging_dir = env_logging_dir
+        assert logging_dir is not None, (
+            "Please provide a logging directory or set TENSORBOARD_LOGGING_DIR"
+        )
+        super().__init__(project_name, logging_dir)
+
+    def store_init_configuration(self, values: Dict[str, Any]):
+        values = json.loads(json.dumps(flatten_dict(values), default=str))
+        return super().store_init_configuration(values)
+
+    def log(self, values: dict, step: Optional[int] = None, **kwargs):
+        values = flatten_dict(values)
+        return super().log(values, step, **kwargs)
+
+
+@edsnlp.registry.loggers.register("aim", auto_draft_in_config=True)
+def AimLogger(
+    project_name: str,
+    logging_dir: Optional[Union[str, os.PathLike]] = None,
+    **kwargs,
+) -> "accelerate.tracking.AimTracker":  # pragma: no cover
+    """
+    Logger for [Aim](https://github.com/aimhubio/aim).
+    This logger is also available via the loggers registry as `aim`.
+
+    Parameters
+    ----------
+    project_name: str
+        Name of the project.
+    logging_dir: Optional[Union[str, os.PathLike]]
+        Directory in which to store the Aim logs.
+        The environment variable `AIM_LOGGING_DIR` takes precedence over this argument.
+    kwargs: Dict
+        Additional keyword arguments to pass to the Aim init function.
+    """
+
+    env_logging_dir = os.environ.get("AIM_LOGGING_DIR", None)
+    if env_logging_dir is not None and logging_dir is not None:  # pragma: no cover
+        warnings.warn(
+            f"Using the env AIM_LOGGING_DIR={env_logging_dir} as the logging "
+            f"directory for Aim, instead of {logging_dir}."
+        )
+        logging_dir = env_logging_dir
+    assert logging_dir is not None, (
+        "Please provide a logging directory or set AIM_LOGGING_DIR"
+    )
+
+    return accelerate.tracking.AimTracker(project_name, logging_dir, **kwargs)
+
+
+@edsnlp.registry.loggers.register("wandb", auto_draft_in_config=True)
+def WandBLogger(
+    project_name: str,
+    **kwargs,
+) -> "accelerate.tracking.WandBTracker":  # pragma: no cover
+    """
+    Logger for [Weights & Biases](https://docs.wandb.ai/quickstart/).
+    This logger is also available via the loggers registry as `wandb`.
+
+    Parameters
+    ----------
+    project_name: str
+        Name of the project. This will become the `project`
+        parameter in `wandb.init`.
+    kwargs: Dict
+        Additional keyword arguments to pass to the WandB init function.
+ + Returns + ------- + accelerate.tracking.WandBTracker + """ + return accelerate.tracking.WandBTracker(project_name, **kwargs) + + +@edsnlp.registry.loggers.register("mlflow", auto_draft_in_config=True) +def MLflowLogger( + project_name: str, + logging_dir: Optional[Union[str, os.PathLike]] = None, + run_id: Optional[str] = None, + tags: Optional[Union[Dict[str, Any], str]] = None, + nested_run: Optional[bool] = False, + run_name: Optional[str] = None, + description: Optional[str] = None, +) -> "accelerate.tracking.MLflowTracker": # pragma: no cover + """ + Logger for + [MLflow](https://mlflow.org/docs/latest/getting-started/intro-quickstart/). + This logger is also available via the loggers registry as `mlflow`. + + Parameters + ---------- + project_name: str + Name of the project. This will become the mlflow experiment name. + logging_dir: Optional[Union[str, os.PathLike]] + Directory in which to store the MLflow logs. + run_id: Optional[str] + If specified, get the run with the specified UUID and log parameters and metrics + under that run. The run’s end time is unset and its status is set to running, + but the run’s other attributes (source_version, source_type, etc.) are not + changed. Environment variable MLFLOW_RUN_ID has priority over this argument. + tags: Optional[Union[Dict[str, Any], str]] + An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to + set as tags on the run. If a run is being resumed, these tags are set on the + resumed run. If a new run is being created, these tags are set on the new run. + Environment variable MLFLOW_TAGS has priority over this argument. + nested_run: Optional[bool] + Controls whether run is nested in parent run. True creates a nested run. + Environment variable MLFLOW_NESTED_RUN has priority over this argument. + run_name: Optional[str] + Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is + unspecified. + description: Optional[str] + An optional string that populates the description box of the run. If a run is + being resumed, the description is set on the resumed run. If a new run is being + created, the description is set on the new run. + + Returns + ------- + accelerate.tracking.MLflowTracker + """ + return accelerate.tracking.MLflowTracker( + project_name, + logging_dir=logging_dir, + run_id=run_id, + tags=tags, + nested_run=nested_run, + run_name=run_name, + description=description, + ) + + +@edsnlp.registry.loggers.register("cometml", auto_draft_in_config=True) +def CometMLLogger( + project_name: str, + **kwargs, +) -> "accelerate.tracking.CometMLTracker": # pragma: no cover + """ + Logger for [CometML](https://www.comet.com/docs/). + This logger is also available via the loggers registry as `cometml`. + + Parameters + ---------- + project_name: str + Name of the project. + kwargs: Dict + Additional keyword arguments to pass to the CometML Experiment + object. + + Returns + ------- + accelerate.tracking.CometMLTracker + """ + return accelerate.tracking.CometMLTracker(project_name, **kwargs) + + +try: + from accelerate.tracking import ClearMLTracker as _ClearMLTracker + + @edsnlp.registry.loggers.register("clearml", auto_draft_in_config=True) + def ClearMLLogger( + project_name: str, + **kwargs, + ) -> "accelerate.tracking.ClearMLTracker": # pragma: no cover + """ + Logger for + [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps/). + This logger is also available via the loggers registry as `clearml`. 
+ + Parameters + ---------- + project_name: str + Name of the experiment. Environment variables `CLEARML_PROJECT` and + `CLEARML_TASK` have priority over this argument. + kwargs: Dict + Additional keyword arguments to pass to the ClearML Task object. + + Returns + ------- + accelerate.tracking.ClearMLTracker + """ + return _ClearMLTracker(project_name, **kwargs) +except ImportError: # pragma: no cover + + def ClearMLLogger(*args, **kwargs): + raise ImportError("ClearMLLogger is not available.") + + +try: + from accelerate.tracking import DVCLiveTracker as _DVCLiveTracker + + @edsnlp.registry.loggers.register("dvclive", auto_draft_in_config=True) + def DVCLiveLogger( + live: Any = None, + **kwargs, + ) -> "accelerate.tracking.DVCLiveTracker": # pragma: no cover + """ + Logger for [DVC Live](https://dvc.org/doc/dvclive). + This logger is also available via the loggers registry as `dvclive`. + + Parameters + ---------- + live: dvclive.Live + An instance of `dvclive.Live` to use for logging. + kwargs: Dict + Additional keyword arguments to pass to the `dvclive.Live` constructor. + + Returns + ------- + accelerate.tracking.DVCLiveTracker + """ + return _DVCLiveTracker(None, live=live, **kwargs) +except ImportError: # pragma: no cover + + def DVCLiveLogger(*args, **kwargs): + raise ImportError("DVCLiveLogger is not available.") diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py index d877a3895..05bc935b9 100644 --- a/edsnlp/training/trainer.py +++ b/edsnlp/training/trainer.py @@ -1,4 +1,3 @@ -import json import math import os import time @@ -14,20 +13,23 @@ Collection, Dict, Iterable, + List, Optional, Sequence, Union, ) import torch -from accelerate import Accelerator +from accelerate import Accelerator, PartialState +from accelerate.tracking import GeneralTracker from accelerate.utils import gather_object from confit import validate_arguments +from confit.registry import Draft from confit.utils.random import set_seed -from rich_logger import RichTablePrinter from tqdm import tqdm, trange from typing_extensions import Literal +import edsnlp from edsnlp import Pipeline, registry from edsnlp.core.stream import Stream from edsnlp.metrics.ner import NerMetric @@ -73,17 +75,6 @@ } -def flatten_dict(d, path=""): - if not isinstance(d, dict): - return {path: d} - - return { - k: v - for key, val in d.items() - for k, v in flatten_dict(val, f"{path}/{key}" if path else key).items() - } - - def fill_flat_stats(x, result, path=()): if result is None: result = {} @@ -400,6 +391,23 @@ def forward(self, batch, enable: Optional[Sequence[str]] = None): return all_results, loss +def get_logger( + logger: Union[bool, AsList[Union[str, Draft[GeneralTracker], GeneralTracker]]], + project_name, + logging_dir, + **kwargs, +) -> List[GeneralTracker]: + logger = ["rich", "json"] if logger is True else [] if not logger else logger + logger = [ + edsnlp.registry.loggers.get(n)() if isinstance(n, str) else n for n in logger + ] + logger = [ + Draft.instantiate(obj, project_name=project_name, logging_dir=logging_dir) + for obj in logger + ] + return logger + + @validate_arguments(registry=registry) def train( *, @@ -408,7 +416,7 @@ def train( val_data: AsList[Stream] = [], seed: int = 42, max_steps: int = 1000, - optimizer: Union[ScheduledOptimizer, torch.optim.Optimizer] = None, + optimizer: Union[ScheduledOptimizer, Draft[ScheduledOptimizer], torch.optim.Optimizer] = None, # noqa: E501 validation_interval: Optional[int] = None, checkpoint_interval: Optional[int] = None, grad_max_norm: float = 5.0, @@ 
-423,12 +431,12 @@ train(
     output_dir: Union[Path, str] = Path("artifacts"),
     output_model_dir: Optional[Union[Path, str]] = None,
     save_model: bool = True,
-    logger: bool = True,
+    logger: Union[bool, AsList[Union[str, GeneralTracker, Draft[GeneralTracker]]]] = True,  # noqa: E501
     log_weight_grads: bool = False,
     on_validation_callback: Optional[Callable[[Dict], None]] = None,
     config_meta: Optional[Dict] = None,
     **kwargs,
-):
+):  # fmt: skip
     """
     Train a pipeline.

@@ -456,7 +464,7 @@ train(
         The random seed
     max_steps: int
         The maximum number of training steps
-    optimizer: Union[ScheduledOptimizer, torch.optim.Optimizer]
+    optimizer: Union[ScheduledOptimizer, Draft[ScheduledOptimizer], torch.optim.Optimizer]
         The optimizer. If None, a default optimizer will be used.

         ??? note "`ScheduledOptimizer` object/dictionary"
@@ -524,8 +532,17 @@ train(
         Whether to save the model or not. This can be useful if you are only
        interested in the metrics, but not the model, and want to avoid
        spending time dumping the model weights to the disk.
-    logger: bool
-        Whether to log the validation metrics in a rich table.
+    logger: Union[bool, AsList[Union[str, Draft[GeneralTracker], GeneralTracker]]]
+        The logger to use. Can be a boolean to use the default loggers (rich
+        and json), a list of logger names, or a list of logger objects.
+
+        You can use Hugging Face `accelerate`'s integrated loggers (`tensorboard`,
+        `wandb`, `comet_ml`, `aim`, `mlflow`, `clearml`, `dvclive`), or
+        EDS-NLP simple loggers, or a combination of both:
+
+        - `csv`: logs to a CSV file in `output_dir` (`artifacts/metrics.csv`)
+        - `json`: logs to a JSON file in `output_dir` (`artifacts/metrics.json`)
+        - `rich`: logs to a rich table in the terminal
     log_weight_grads: bool
         Whether to log the weight gradients during training.
on_validation_callback: Optional[Callable[[Dict], None]]
@@ -537,9 +554,21 @@ train(
     -------
     Pipeline
         The trained pipeline
-    """
-    # Prepare paths
-    accelerator = Accelerator(cpu=cpu, mixed_precision=mixed_precision)
+    """  # noqa: E501
+    # hack to ensure cpu is set before the accelerator is indirectly initialized
+    # when creating the trackers
+    PartialState(cpu=cpu)
+    project_name = str(Path.cwd() if config_meta is None else config_meta["config_path"][0])  # fmt: skip # noqa: E501
+    accelerator = Accelerator(
+        cpu=cpu,
+        mixed_precision=mixed_precision,
+        log_with=get_logger(
+            logger,
+            # default project name, the user can override this when creating the logger
+            project_name=project_name,
+            logging_dir=output_dir,
+        ),
+    )
     # accelerator.register_for_checkpointing(dataset)
     is_main_process = accelerator.is_main_process
     device = accelerator.device
@@ -552,13 +581,18 @@ train(

     output_dir = Path(output_dir or Path.cwd() / "artifacts")
     output_model_dir = Path(output_model_dir or output_dir / "model-last")
-    train_metrics_path = output_dir / "train_metrics.json"
+    unresolved_config = None
     if is_main_process:
         os.makedirs(output_dir, exist_ok=True)
         os.makedirs(output_model_dir, exist_ok=True)
         if config_meta is not None:  # pragma: no cover
-            print(config_meta["unresolved_config"].to_yaml_str())
-            config_meta["unresolved_config"].to_disk(output_dir / "train_config.yml")
+            unresolved_config = config_meta["unresolved_config"]
+            print(unresolved_config.to_yaml_str())
+            unresolved_config.to_disk(output_dir / "train_config.yml")
+    # TODO: handle config_meta is None
+    accelerator.init_trackers(
+        project_name, config=unresolved_config
+    )  # in theory project name shouldn't be used

     validation_interval = validation_interval or max_steps // 10
     checkpoint_interval = checkpoint_interval or validation_interval
@@ -575,7 +609,7 @@ train(
     del optimizer
     if optim is None:
         warnings.warn(
-            "No optimizer provided, using default optimizer with default " "parameters"
+            "No optimizer provided, using default optimizer with default parameters"
         )
         optim = default_optim(
             [nlp.get_pipe(name) for name in trainable_pipe_names],
@@ -586,6 +620,11 @@ train(
                 if k in kwargs
             },
         )
+    optim: torch.optim.Optimizer = Draft.instantiate(
+        optim,
+        module=nlp,
+        total_steps=max_steps,
+    )

     if kwargs:
         raise ValueError(f"Unknown arguments: {', '.join(kwargs)}")
@@ -656,181 +695,161 @@ train(
         ewm_state = grad_mean = grad_var = None
         default_metrics = dict(count=0, spikes=0)
         cumulated_data = defaultdict(lambda: 0, **default_metrics)
-        all_metrics = []
         set_seed(seed)
-        with (
-            RichTablePrinter(LOGGER_FIELDS, auto_refresh=False)
-            if is_main_process and logger
-            else nullcontext()
-        ) as logger:
-            # Training loop
-            for step in trange(
-                max_steps + 1,
-                desc="Training model",
-                leave=True,
-                mininterval=5.0,
-                total=max_steps,
-                disable=not is_main_process,
-                smoothing=0.3,
-            ):
+        # Training loop
+        for step in trange(
+            max_steps + 1,
+            desc="Training model",
+            leave=True,
+            mininterval=5.0,
+            total=max_steps,
+            disable=not is_main_process,
+            smoothing=0.3,
+        ):
+            if save_model and is_main_process and (step % checkpoint_interval) == 0:
+                nlp.to_disk(output_model_dir)
+            if is_main_process and step > 0 and (step % validation_interval) == 0:
+                scores = scorer(nlp, val_docs) if val_docs else {}
+                metrics = {
+                    "step": step,
+                    "lr": accel_optim.param_groups[0]["lr"],
+                    **cumulated_data,
+                    **scores,
+                }
+                cumulated_data = defaultdict(lambda: 0, **default_metrics)
+                accelerator.log(metrics, step=step)

+                if on_validation_callback:
on_validation_callback(metrics) + + if step == max_steps: + break + + accel_optim.zero_grad() + + batches = list(next(iterator)) + batches_pipe_names = list( + flatten_once( + [ + [td.pipe_names or pipe_names] * len(b) + for td, b in zip(phase_training_data, batches) + ] + ) + ) + batches = list(flatten(batches)) + + # Synchronize stats between sub-batches across workers + local_batch_stats = {} + for b in batches: + fill_flat_stats(b, result=local_batch_stats) + batch_stats = gather_object([local_batch_stats]) + batch_stats = {k: sum(v) for k, v in ld_to_dl(batch_stats).items()} + for b in batches: + set_flat_stats(b, batch_stats) + + local_res_stats = defaultdict(lambda: 0.0) + for idx, (batch, batch_pipe_names) in enumerate( + zip(batches, batches_pipe_names) ): - if ( - save_model - and is_main_process - and (step % checkpoint_interval) == 0 - ): - # torch.save(nlp, output_model_dir / "model.pt") - nlp.to_disk(output_model_dir) - if ( - is_main_process - and step > 0 - and (step % validation_interval) == 0 - ): - scores = scorer(nlp, val_docs) if val_docs else {} - metrics = { - "step": step, - "lr": accel_optim.param_groups[0]["lr"], - **cumulated_data, - **scores, - } - all_metrics.append(metrics) - cumulated_data = defaultdict(lambda: 0, **default_metrics) - train_metrics_path.write_text(json.dumps(all_metrics, indent=2)) - if logger: - logger.log_metrics(flatten_dict(metrics)) - - if on_validation_callback: - on_validation_callback(metrics) - - if step == max_steps: - break - - accel_optim.zero_grad() - - batches = list(next(iterator)) - batches_pipe_names = list( - flatten_once( - [ - [td.pipe_names or pipe_names] * len(b) - for td, b in zip(phase_training_data, batches) - ] - ) + cache_ctx = ( + nlp.cache() if len(batch_pipe_names) > 1 else nullcontext() + ) + no_sync_ctx = ( + accelerator.no_sync(trained_pipes) + if idx < len(batches) - 1 + else nullcontext() ) - batches = list(flatten(batches)) - - # Synchronize stats between sub-batches across workers - local_batch_stats = {} - for b in batches: - fill_flat_stats(b, result=local_batch_stats) - batch_stats = gather_object([local_batch_stats]) - batch_stats = {k: sum(v) for k, v in ld_to_dl(batch_stats).items()} - for b in batches: - set_flat_stats(b, batch_stats) - - local_res_stats = defaultdict(lambda: 0.0) - for idx, (batch, batch_pipe_names) in enumerate( - zip(batches, batches_pipe_names) - ): - cache_ctx = ( - nlp.cache() if len(batch_pipe_names) > 1 else nullcontext() + try: + with cache_ctx, no_sync_ctx: + all_res, loss = trained_pipes( + batch, + enable=batch_pipe_names, + ) + for name, res in all_res.items(): + for k, v in res.items(): + if ( + isinstance(v, (float, int)) + or isinstance(v, torch.Tensor) + and v.ndim == 0 + ): + local_res_stats[k] += float(v) + del k, v + del res + del all_res + if isinstance(loss, torch.Tensor) and loss.requires_grad: + accelerator.backward(loss) + except torch.cuda.OutOfMemoryError: # pragma: no cover + print( + "Out of memory error encountered when processing a " + "batch with the following statistics:" + ) + print(local_batch_stats) + raise + del loss + + # Sync output stats after forward such as losses, supports, etc. 
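+        # (gather_object returns one stats dict per process; ld_to_dl regroups
+        # them into per-key lists, so the sums below are totals across all
+        # workers.)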
+ res_stats = { + k: sum(v) + for k, v in ld_to_dl(gather_object([dict(local_res_stats)])).items() + } + if is_main_process: + for k, v in batch_stats.items(): + cumulated_data[k] += v + for k, v in res_stats.items(): + cumulated_data[k] += v + + del batch_stats, res_stats + accelerator.unscale_gradients() + + # Log gradients + if log_weight_grads: + for pipe_name, pipe in trained_pipes_local.items(): + for param_name, param in pipe.named_parameters(): + if param.grad is not None: + cumulated_data[ + f"grad_norm/{pipe_name}/{param_name}" + ] += param.grad.norm().item() + cumulated_data[ + f"param_norm/{pipe_name}/{param_name}" + ] += param.norm().item() + + grad_norm = torch.nn.utils.clip_grad_norm_( + grad_params, grad_max_norm, norm_type=2 + ).item() + + # Detect grad spikes and skip the step if necessary + if grad_dev_policy is not None: + if step > grad_ewm_window and ( + grad_norm - grad_mean + ) > grad_max_dev * math.sqrt(grad_var): + spike = True + cumulated_data["spikes"] += 1 + else: + grad_mean, grad_var, ewm_state = ewm_moments( + grad_norm, grad_ewm_window, state=ewm_state ) - no_sync_ctx = ( - accelerator.no_sync(trained_pipes) - if idx < len(batches) - 1 - else nullcontext() + spike = False + + if spike and grad_dev_policy == "clip_mean": + torch.nn.utils.clip_grad_norm_( + grad_params, grad_mean, norm_type=2 + ) + elif spike and grad_dev_policy == "clip_threshold": + torch.nn.utils.clip_grad_norm_( + grad_params, + grad_mean + math.sqrt(grad_var) * grad_max_dev, + norm_type=2, ) - try: - with cache_ctx, no_sync_ctx: - all_res, loss = trained_pipes( - batch, - enable=batch_pipe_names, - ) - for name, res in all_res.items(): - for k, v in res.items(): - if ( - isinstance(v, (float, int)) - or isinstance(v, torch.Tensor) - and v.ndim == 0 - ): - local_res_stats[k] += float(v) - del k, v - del res - del all_res - if ( - isinstance(loss, torch.Tensor) - and loss.requires_grad - ): - accelerator.backward(loss) - except torch.cuda.OutOfMemoryError: - print( - "Out of memory error encountered when processing a " - "batch with the following statistics:" - ) - print(local_batch_stats) - raise - del loss - - # Sync output stats after forward such as losses, supports, etc. 
- res_stats = { - k: sum(v) - for k, v in ld_to_dl( - gather_object([dict(local_res_stats)]) - ).items() - } - if is_main_process: - for k, v in batch_stats.items(): - cumulated_data[k] += v - for k, v in res_stats.items(): - cumulated_data[k] += v - - del batch_stats, res_stats - accelerator.unscale_gradients() - - # Log gradients - if log_weight_grads: - for pipe_name, pipe in trained_pipes_local.items(): - for param_name, param in pipe.named_parameters(): - if param.grad is not None: - cumulated_data[ - f"grad_norm/{pipe_name}/{param_name}" - ] += param.grad.norm().item() - cumulated_data[ - f"param_norm/{pipe_name}/{param_name}" - ] += param.norm().item() - - grad_norm = torch.nn.utils.clip_grad_norm_( - grad_params, grad_max_norm, norm_type=2 - ).item() - - # Detect grad spikes and skip the step if necessary - if grad_dev_policy is not None: - if step > grad_ewm_window and ( - grad_norm - grad_mean - ) > grad_max_dev * math.sqrt(grad_var): - spike = True - cumulated_data["spikes"] += 1 - else: - grad_mean, grad_var, ewm_state = ewm_moments( - grad_norm, grad_ewm_window, state=ewm_state - ) - spike = False - if spike and grad_dev_policy == "clip_mean": - torch.nn.utils.clip_grad_norm_( - grad_params, grad_mean, norm_type=2 - ) - elif spike and grad_dev_policy == "clip_threshold": - torch.nn.utils.clip_grad_norm_( - grad_params, - grad_mean + math.sqrt(grad_var) * grad_max_dev, - norm_type=2, - ) + if grad_dev_policy != "skip" or not spike: + accel_optim.step() - if grad_dev_policy != "skip" or not spike: - accel_optim.step() + cumulated_data["count"] += 1 + cumulated_data["grad_norm/__all__"] += grad_norm - cumulated_data["count"] += 1 - cumulated_data["grad_norm/__all__"] += grad_norm + del iterator - del iterator + # Should we put this in a finally block? 
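+    # end_training() flushes and closes the trackers created by init_trackers
+    # (e.g. file writers); wrapping the loop in try/finally would also close
+    # them when training raises mid-run.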
+ accelerator.end_training() return nlp diff --git a/edsnlp/utils/typing.py b/edsnlp/utils/typing.py index 5dd675c21..bc6f36fe9 100644 --- a/edsnlp/utils/typing.py +++ b/edsnlp/utils/typing.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Generic, List, TypeVar, Union import pydantic +from confit import Validatable from confit.errors import patch_errors T = TypeVar("T") @@ -11,19 +12,7 @@ from pydantic_core import core_schema -class Validated: - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def __get_pydantic_core_schema__(cls, source, handler): - return core_schema.chain_schema( - [ - core_schema.no_info_plain_validator_function(v) - for v in cls.__get_validators__() - ] - ) +Validated = Validatable class MetaAsList(type): diff --git a/mkdocs.yml b/mkdocs.yml index 46dfd19e3..db045f299 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -142,6 +142,8 @@ nav: - data/polars.md - data/spark.md - data/converters.md + - Training: + - training/loggers.md - Concepts: - concepts/pipeline.md - concepts/torch-component.md diff --git a/pyproject.toml b/pyproject.toml index 311780ac9..59577e3b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ dependencies = [ "regex", "spacy>=3.2,<3.8", "thinc<8.2.5", # we don't need thinc but spacy depdends on it 8.2.5 cause binary issues - "confit>=0.7.3", + #"confit>=0.7.3", + "confit @ git+https://github.com/aphp/confit@allow-partial", "tqdm", "umls-downloader>=0.1.1", "numpy>=1.15.0,<1.23.2; python_version<'3.8'", @@ -87,6 +88,7 @@ docs = [ dev = [ "edsnlp[dev-no-ml]", "edsnlp[ml]", + "tensorboard", "optuna>=4.0.0", "plotly>=5.18.0", # required by optuna viz ] @@ -290,6 +292,18 @@ where = ["."] "standoff" = "edsnlp.data:write_standoff" "brat" = "edsnlp.data:write_brat" # alias for standoff +[project.entry-points."edsnlp_loggers"] +"csv" = "edsnlp.training.loggers:CSVLogger" +"json" = "edsnlp.training.loggers:JSONLogger" +"rich" = "edsnlp.training.loggers:RichLogger" +"tensorboard" = "edsnlp.training.loggers:TensorBoardLogger" +"aim" = "edsnlp.training.loggers:AimLogger" +"wandb" = "edsnlp.training.loggers:WandBLogger" +"clearml" = "edsnlp.training.loggers:ClearMLLogger" +"mlflow" = "edsnlp.training.loggers:MLflowLogger" +"cometml" = "edsnlp.training.loggers:CometMLLogger" +"dvclive" = "edsnlp.training.loggers:DVCLiveLogger" + [project.entry-points."spacy_misc"] "eds.span_context_getter" = "edsnlp.utils.span_getters:make_span_context_getter" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 8e289fea2..119e1bdca 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -12,7 +12,7 @@ import edsnlp import edsnlp.pipes as eds from edsnlp import Pipeline, registry -from edsnlp.core.registries import CurriedFactory +from edsnlp.core.registries import DraftPipe from edsnlp.pipes.base import BaseComponent try: @@ -378,9 +378,9 @@ def test_curried_nlp_pipe(): nlp.add_pipe(eds.sections(), name="sections") pipe = CustomComponent() - assert isinstance(pipe, CurriedFactory) + assert isinstance(pipe, DraftPipe) err = ( - f"This component CurriedFactory({pipe.factory}) has not been instantiated " + f"This Draft[{pipe._func.__qualname__}] component has not been instantiated " f"yet, likely because it was missing an `nlp` pipeline argument. 
You should " f"either:\n" "- add it to a pipeline: `pipe = nlp.add_pipe(pipe)`\n" diff --git a/tests/training/ner_qlf_diff_bert_config.yml b/tests/training/ner_qlf_diff_bert_config.yml index 6b8ce327a..148f52b48 100644 --- a/tests/training/ner_qlf_diff_bert_config.yml +++ b/tests/training/ner_qlf_diff_bert_config.yml @@ -25,8 +25,8 @@ nlp: embedding: '@factory': eds.transformer model: hf-internal-testing/tiny-bert - window: 128 - stride: 96 + window: 256 + stride: 128 new_tokens: [ [ "(?:\\n\\s*)*\\n", "⏎" ] ] qualifier: @@ -37,15 +37,15 @@ nlp: embedding: '@factory': eds.span_pooler - embedding: # ${ nlp.components.ner.embedding } - '@factory': eds.text_cnn - kernel_sizes: [ 3 ] - - embedding: - '@factory': eds.transformer - model: hf-internal-testing/tiny-bert - window: 128 - stride: 96 + embedding: ${ nlp.components.ner.embedding } +# '@factory': eds.text_cnn +# kernel_sizes: [ 3 ] +# +# embedding: +# '@factory': eds.transformer +# model: hf-internal-testing/tiny-bert +# window: 128 +# stride: 96 # 📈 SCORERS scorer: @@ -125,3 +125,4 @@ train: num_workers: 0 optimizer: ${ optimizer } grad_dev_policy: "skip" + logger: ["csv", "json", "rich"] diff --git a/tests/training/ner_qlf_same_bert_config.yml b/tests/training/ner_qlf_same_bert_config.yml index a429b1d5a..337fd2aa1 100644 --- a/tests/training/ner_qlf_same_bert_config.yml +++ b/tests/training/ner_qlf_same_bert_config.yml @@ -116,3 +116,4 @@ train: num_workers: 0 optimizer: ${ optimizer } grad_dev_policy: "clip_threshold" + logger: ["tensorboard", "rich"] diff --git a/tests/training/test_train.py b/tests/training/test_train.py index b5eca9f07..8a1e4f03b 100644 --- a/tests/training/test_train.py +++ b/tests/training/test_train.py @@ -1,4 +1,5 @@ # ruff:noqa:E402 +import os.path import pytest @@ -28,6 +29,7 @@ from edsnlp.core.registries import registry from edsnlp.data.converters import AttributesMappingArg, get_current_tokenizer from edsnlp.metrics.dep_parsing import DependencyParsingMetric +from edsnlp.training.loggers import CSVLogger from edsnlp.training.optimizer import LinearSchedule, ScheduledOptimizer from edsnlp.training.trainer import GenericScorer, train from edsnlp.utils.span_getters import SpanSetterArg, set_spans @@ -97,7 +99,16 @@ def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path): config = Config.from_disk("ner_qlf_diff_bert_config.yml") shutil.rmtree(tmp_path, ignore_errors=True) kwargs = Config.resolve(config["train"], registry=registry, root=config) - nlp = train(**kwargs, output_dir=tmp_path, cpu=True) + nlp = train( + **kwargs, + output_dir=tmp_path, + cpu=True, + config_meta={ + "config_path": "dep_parser_config.yml", + "resolved_config": kwargs, + "unresolved_config": config, + }, + ) scorer = GenericScorer(**kwargs["scorer"]) val_data = kwargs["val_data"] last_scores = scorer(nlp, val_data) @@ -108,13 +119,26 @@ def test_ner_qualif_train_diff_bert(run_in_test_dir, tmp_path): assert last_scores["ner"]["micro"]["f"] > 0.4 assert last_scores["qual"]["micro"]["f"] > 0.4 + # Ensure we saved the metrics + assert os.path.exists(tmp_path / "metrics.json") + assert os.path.exists(tmp_path / "metrics.csv") + def test_ner_qualif_train_same_bert(run_in_test_dir, tmp_path): set_seed(42) config = Config.from_disk("ner_qlf_same_bert_config.yml") shutil.rmtree(tmp_path, ignore_errors=True) kwargs = Config.resolve(config["train"], registry=registry, root=config) - nlp = train(**kwargs, output_dir=tmp_path, cpu=True) + nlp = train( + **kwargs, + output_dir=tmp_path, + cpu=True, + config_meta={ + "config_path": 
"dep_parser_config.yml", + "resolved_config": kwargs, + "unresolved_config": config, + }, + ) scorer = GenericScorer(**kwargs["scorer"]) val_data = kwargs["val_data"] last_scores = scorer(nlp, val_data) @@ -131,7 +155,16 @@ def test_qualif_train(run_in_test_dir, tmp_path): config = Config.from_disk("qlf_config.yml") shutil.rmtree(tmp_path, ignore_errors=True) kwargs = Config.resolve(config["train"], registry=registry, root=config) - nlp = train(**kwargs, output_dir=tmp_path, cpu=True) + nlp = train( + **kwargs, + output_dir=tmp_path, + cpu=True, + config_meta={ + "config_path": "dep_parser_config.yml", + "resolved_config": kwargs, + "unresolved_config": config, + }, + ) scorer = GenericScorer(**kwargs["scorer"]) val_data = kwargs["val_data"] last_scores = scorer(nlp, val_data) @@ -147,7 +180,17 @@ def test_dep_parser_train(run_in_test_dir, tmp_path): config = Config.from_disk("dep_parser_config.yml") shutil.rmtree(tmp_path, ignore_errors=True) kwargs = Config.resolve(config["train"], registry=registry, root=config) - nlp = train(**kwargs, output_dir=tmp_path, cpu=True) + nlp = train( + **kwargs, + logger=CSVLogger.draft(), + output_dir=tmp_path, + cpu=True, + config_meta={ + "config_path": "dep_parser_config.yml", + "resolved_config": kwargs, + "unresolved_config": config, + }, + ) scorer = GenericScorer(**kwargs["scorer"]) val_data = list(kwargs["val_data"]) last_scores = scorer(nlp, val_data) From afe2eb3f3a26d203e8f70c964bf4848367a5dea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Sun, 16 Feb 2025 23:09:10 +0100 Subject: [PATCH 4/6] docs: make more registered function clickable in the docs --- changelog.md | 1 + docs/scripts/clickable_snippets.py | 72 +++++++++++++++++++++++------- 2 files changed, 56 insertions(+), 17 deletions(-) diff --git a/changelog.md b/changelog.md index 48f2c35f9..d36a262c2 100644 --- a/changelog.md +++ b/changelog.md @@ -12,6 +12,7 @@ - `ScheduledOptimizer` (e.g., `@core: "optimizer"`) now supports importing optimizers using their qualified name (e.g., `optim: "torch.optim.Adam"`). - Added grad spike detection to the `edsnlp.train` script, and per weight layer gradient logging. - Added support for multiple loggers (`tensorboard`, `wandb`, `comet_ml`, `aim`, `mlflow`, `clearml`, `dvclive`, `csv`, `json`, `rich`) in `edsnlp.train` via the `logger` parameter. Default is [`json` and `rich`] for backward compatibility. 
+- Added clickable snippets in the documentation for more registered functions ### Changed diff --git a/docs/scripts/clickable_snippets.py b/docs/scripts/clickable_snippets.py index 2b901448a..98b9440c8 100644 --- a/docs/scripts/clickable_snippets.py +++ b/docs/scripts/clickable_snippets.py @@ -1,7 +1,7 @@ # Based on https://github.com/darwindarak/mdx_bib import os -import re from bisect import bisect_right +from collections import defaultdict from typing import Tuple import jedi @@ -22,11 +22,7 @@ from bs4 import BeautifulSoup -BRACKET_RE = re.compile(r"\[([^\[]+)\]") -CITE_RE = re.compile(r"@([\w_:-]+)") -DEF_RE = re.compile(r"\A {0,3}\[@([\w_:-]+)\]:\s*(.*)") -INDENT_RE = re.compile(r"\A\t| {4}(.*)") - +# Used to match href in HTML to replace with a relative path HREF_REGEX = ( r"(?<=<\s*(?:a[^>]*href|img[^>]*src)=)" r'(?:"([^"]*)"|\'([^\']*)|[ ]*([^ =>]*)(?![a-z]+=))' @@ -42,6 +38,15 @@ (?![a-zA-Z0-9._-]) """ +REGISTRY_REGEX = r"""(?x) +(?]*>(?:"|&\#39;|")@([a-zA-Z0-9._-]*)(?:"|&\#39;|")<\/span>\s* +]*>:<\/span>\s* +]*>\s*<\/span>\s* +]*>(?:"|&\#39;|")?([a-zA-Z0-9._-]*)(?:"|&\#39;|")?<\/span> +(?![a-zA-Z0-9._-]) +""" + CITATION_RE = r"(\[@(?:[\w_:-]+)(?: *, *@(?:[\w_:-]+))*\])" @@ -62,11 +67,15 @@ def on_config(self, config: MkDocsConfig): plugin.load_config(plugin_config) @classmethod - def get_ep_namespace(cls, ep, namespace): + def get_ep_namespace(cls, ep, namespace=None): if hasattr(ep, "select"): - return ep.select(group=namespace) + return ep.select(group=namespace) if namespace else list(ep._all) else: # dict - return ep.get(namespace, []) + return ( + ep.get(namespace, []) + if namespace + else (x for g in ep.values() for x in g) + ) @mkdocs.plugins.event_priority(-1000) def on_post_page( @@ -94,18 +103,26 @@ def on_post_page( autorefs: AutorefsPlugin = config["plugins"]["autorefs"] ep = entry_points() page_url = os.path.join("/", page.file.url) - spacy_factories_entry_points = { + factories_entry_points = { ep.name: ep.value for ep in ( *self.get_ep_namespace(ep, "spacy_factories"), *self.get_ep_namespace(ep, "edsnlp_factories"), ) } - - def replace_component(match): - full_group = match.group(0) + all_entry_points = defaultdict(dict) + for ep in self.get_ep_namespace(ep): + if ep.group.startswith("edsnlp_") or ep.group.startswith("spacy_"): + group = ep.group.split("_", 1)[1] + all_entry_points[group][ep.name] = ep.value + + # This method is meant for replacing any component that + # appears in a "eds.component" format, no matter if it is + # preceded by a "@factory" or not. + def replace_factory_component(match): + full_match = match.group(0) name = "eds." 
+ match.group(1) - ep = spacy_factories_entry_points.get(name) + ep = factories_entry_points.get(name) preceding = output[match.start(0) - 50 : match.start(0)] if ep is not None and "DEFAULT:" not in preceding: try: @@ -114,7 +131,27 @@ def replace_component(match): pass else: return f"{name}" - return full_group + return full_match + + # This method is meant for replacing any component that + # appears in a "@registry": "component" format + def replace_any_registry_component(match): + full_match = match.group(0) + group = match.group(1) + name = match.group(2) + ep = all_entry_points[group].get(name) + preceding = output[match.start(0) - 50 : match.start(0)] + if ep is not None and "DEFAULT:" not in preceding: + try: + url = autorefs.get_item_url(ep.replace(":", ".")) + except KeyError: + pass + else: + repl = f'{name}' + before = full_match[: match.start(2) - match.start(0)] + after = full_match[match.end(2) - match.start(0) :] + return before + repl + after + return full_match def replace_link(match): relative_url = url = match.group(1) or match.group(2) or match.group(3) @@ -122,8 +159,9 @@ def replace_link(match): relative_url = os.path.relpath(url, page_url) return f'"{relative_url}"' - output = regex.sub(PIPE_REGEX, replace_component, output) - output = regex.sub(HTML_PIPE_REGEX, replace_component, output) + output = regex.sub(PIPE_REGEX, replace_factory_component, output) + output = regex.sub(HTML_PIPE_REGEX, replace_factory_component, output) + output = regex.sub(REGISTRY_REGEX, replace_any_registry_component, output) all_snippets = "" all_offsets = [] From ca645a550bff4869802ff320a9f8f08cdfddf1ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 6 Sep 2024 02:56:50 +0200 Subject: [PATCH 5/6] feat: new eds.relation_detector_ffn trainable component --- changelog.md | 2 + edsnlp/data/converters.py | 159 +++++--- edsnlp/data/standoff.py | 34 +- edsnlp/extensions.py | 5 +- edsnlp/metrics/relations.py | 142 +++++++ edsnlp/pipes/__init__.py | 1 + edsnlp/pipes/base.py | 36 ++ .../embeddings/span_pooler/span_pooler.py | 4 +- .../relation_detector_ffn/__init__.py | 1 + .../relation_detector_ffn/factory.py | 9 + .../relation_detector_ffn.py | 383 ++++++++++++++++++ edsnlp/training/trainer.py | 31 ++ edsnlp/utils/span_getters.py | 26 +- pyproject.toml | 22 +- tests/data/test_converters.py | 9 + tests/training/dataset_2/sample-1.ann | 6 + tests/training/dataset_2/sample-1.txt | 2 + tests/training/dataset_2/sample-2.ann | 4 + tests/training/dataset_2/sample-2.txt | 1 + tests/training/rel_config.cfg | 79 ++++ tests/training/rel_config.yml | 90 ++++ tests/training/test_train.py | 16 + 22 files changed, 968 insertions(+), 94 deletions(-) create mode 100644 edsnlp/metrics/relations.py create mode 100644 edsnlp/pipes/trainable/relation_detector_ffn/__init__.py create mode 100644 edsnlp/pipes/trainable/relation_detector_ffn/factory.py create mode 100644 edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py create mode 100644 tests/training/dataset_2/sample-1.ann create mode 100644 tests/training/dataset_2/sample-1.txt create mode 100644 tests/training/dataset_2/sample-2.ann create mode 100644 tests/training/dataset_2/sample-2.txt create mode 100644 tests/training/rel_config.cfg create mode 100644 tests/training/rel_config.yml diff --git a/changelog.md b/changelog.md index d36a262c2..337eb8089 100644 --- a/changelog.md +++ b/changelog.md @@ -13,6 +13,8 @@ - Added grad spike detection to the `edsnlp.train` script, and per weight layer gradient logging. 
- Added support for multiple loggers (`tensorboard`, `wandb`, `comet_ml`, `aim`, `mlflow`, `clearml`, `dvclive`, `csv`, `json`, `rich`) in `edsnlp.train` via the `logger` parameter. Default is [`json` and `rich`] for backward compatibility. - Added clickable snippets in the documentation for more registered functions +- New trainable `eds.relation_detector_ffn` component to detect relations between entities. These relations are stored in each entity: `head._.rel[relation_label] = [tail1, tail2, ...]`. +- Load "Status" annotator notes as `status` dict attribute ### Changed diff --git a/edsnlp/data/converters.py b/edsnlp/data/converters.py index 4cca047b0..c8d8ca1d8 100644 --- a/edsnlp/data/converters.py +++ b/edsnlp/data/converters.py @@ -240,76 +240,101 @@ def __init__( def __call__(self, obj, tokenizer=None): # tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer - tok = tokenizer or self.tokenizer or get_current_tokenizer() - doc = tok(obj["text"] or "") - doc._.note_id = obj.get("doc_id", obj.get(FILENAME)) - - spans = [] - - for dst in ( - *(() if self.span_attributes is None else self.span_attributes.values()), - *self.default_attributes, - ): - if not Span.has_extension(dst): - Span.set_extension(dst, default=None) - - for ent in obj.get("entities") or (): - fragments = ( - [ - { - "begin": min(f["begin"] for f in ent["fragments"]), - "end": max(f["end"] for f in ent["fragments"]), - } - ] - if not self.split_fragments - else ent["fragments"] - ) - for fragment in fragments: - span = doc.char_span( - fragment["begin"], - fragment["end"], - label=ent["label"], - alignment_mode="expand", - ) - attributes = ( - {a["label"]: a["value"] for a in ent["attributes"]} - if isinstance(ent["attributes"], list) - else ent["attributes"] + note_id = obj.get("doc_id", obj.get(FILENAME)) + try: + tok = tokenizer or self.tokenizer or get_current_tokenizer() + doc = tok(obj["text"] or "") + doc._.note_id = note_id + + entities = {} + spans = [] + + for dst in ( + *( + () + if self.span_attributes is None + else self.span_attributes.values() + ), + *self.default_attributes, + ): + if not Span.has_extension(dst): + Span.set_extension(dst, default=None) + + for ent in obj.get("entities") or (): + fragments = ( + [ + { + "begin": min(f["begin"] for f in ent["fragments"]), + "end": max(f["end"] for f in ent["fragments"]), + } + ] + if not self.split_fragments + else ent["fragments"] ) - if self.notes_as_span_attribute and ent["notes"]: - ent["attributes"][self.notes_as_span_attribute] = "|".join( - note["value"] for note in ent["notes"] + for fragment in fragments: + span = doc.char_span( + fragment["begin"], + fragment["end"], + label=ent["label"], + alignment_mode="expand", ) - for label, value in attributes.items(): - new_name = ( - self.span_attributes.get(label, None) - if self.span_attributes is not None - else label + attributes = ( + {} + if "attributes" not in ent + else {a["label"]: a["value"] for a in ent["attributes"]} + if isinstance(ent["attributes"], list) + else ent["attributes"] ) - if self.span_attributes is None and not Span.has_extension( - new_name - ): - Span.set_extension(new_name, default=None) - - if new_name: - value = True if value is None else value - if not self.keep_raw_attribute_values: - value = ( - True - if value in ("True", "true") - else False - if value in ("False", "false") - else value - ) - span._.set(new_name, value) - - spans.append(span) - - set_spans(doc, spans, span_setter=self.span_setter) - for attr, value in 
self.default_attributes.items(): - for span in spans: - if span._.get(attr) is None: - span._.set(attr, value) + if self.notes_as_span_attribute and ent["notes"]: + ent["attributes"][self.notes_as_span_attribute] = "|".join( + note["value"] for note in ent["notes"] + ) + for label, value in attributes.items(): + new_name = ( + self.span_attributes.get(label, None) + if self.span_attributes is not None + else label + ) + if self.span_attributes is None and not Span.has_extension( + new_name + ): + Span.set_extension(new_name, default=None) + + if new_name: + value = True if value is None else value + if not self.keep_raw_attribute_values: + value = ( + True + if value in ("True", "true") + else False + if value in ("False", "false") + else value + ) + span._.set(new_name, value) + + entities.setdefault(ent["entity_id"], []).append(span) + spans.append(span) + + set_spans(doc, spans, span_setter=self.span_setter) + for attr, value in self.default_attributes.items(): + for span in spans: + if span._.get(attr) is None: + span._.set(attr, value) + + for relation in obj.get("relations", []): + relation_label = ( + relation["relation_label"] + if "relation_label" in relation + else relation["label"] + ) + from_entity_id = relation["from_entity_id"] + to_entity_id = relation["to_entity_id"] + + for head in entities.get(from_entity_id, ()): + for tail in entities.get(to_entity_id, ()): + head._.rel.setdefault(relation_label, set()).add(tail) + except Exception: + raise ValueError(f"Error when processing {note_id}") return doc diff --git a/edsnlp/data/standoff.py b/edsnlp/data/standoff.py index bcecbf4bf..4bfe0d71b 100644 --- a/edsnlp/data/standoff.py +++ b/edsnlp/data/standoff.py @@ -32,6 +32,7 @@ REGEX_ATTRIBUTE = re.compile(r"^([AM]\d+)\t(.+?) ([TE]\d+)(?: (.+))?$") REGEX_EVENT = re.compile(r"^(E\d+)\t(.+)$") REGEX_EVENT_PART = re.compile(r"(\S+):([TE]\d+)") +REGEX_STATUS = re.compile(r"^(#\d+)\tStatus ([^\t]+)\t(.*)$") class BratParsingError(ValueError): @@ -71,6 +72,7 @@ def parse_standoff_file( entities = {} relations = [] events = {} + doc = {} with fs.open(txt_path, "r", encoding="utf-8") as f: text = f.read() @@ -178,6 +180,11 @@ def parse_standoff_file( "arguments": arguments, } elif line.startswith("#"): + match = REGEX_STATUS.match(line) + if match: + comment = match.group(3) + doc["status"] = comment + continue match = REGEX_NOTE.match(line) if match is None: raise BratParsingError(ann_file, line) @@ -201,6 +208,7 @@ def parse_standoff_file( "entities": list(entities.values()), "relations": relations, "events": list(events.values()), + **doc, } @@ -260,19 +268,19 @@ def dump_standoff_file( ) attribute_idx += 1 - # fmt: off - # if "relations" in doc: - # for i, relation in enumerate(doc["relations"]): - # entity_from = entities_ids[relation["from_entity_id"]] - # entity_to = entities_ids[relation["to_entity_id"]] - # print( - # "R{}\t{} Arg1:{} Arg2:{}\t".format( - # i + 1, str(relation["label"]), entity_from, - # entity_to - # ), - # file=f, - # ) - # fmt: on + # fmt: off + if "relations" in doc: + for i, relation in enumerate(doc["relations"]): + entity_from = entities_ids[relation["from_entity_id"]] + entity_to = entities_ids[relation["to_entity_id"]] + print( + "R{}\t{} Arg1:{} Arg2:{}\t".format( + i + 1, str(relation["label"]), entity_from, + entity_to + ), + file=f, + ) + # fmt: on class StandoffReader(FileBasedReader): diff --git a/edsnlp/extensions.py b/edsnlp/extensions.py index 7127afe5b..be8c87111 100644 --- a/edsnlp/extensions.py +++ b/edsnlp/extensions.py @@ -2,7 +2,7 @@ 
from datetime import date, datetime from dateutil.parser import parse as parse_date -from spacy.tokens import Doc +from spacy.tokens import Doc, Span if not Doc.has_extension("note_id"): Doc.set_extension("note_id", default=None) @@ -43,3 +43,6 @@ def get_note_datetime(doc): if not Doc.has_extension("birth_datetime"): Doc.set_extension("birth_datetime", default=None) + +if not Span.has_extension("rel"): + Span.set_extension("rel", default={}) diff --git a/edsnlp/metrics/relations.py b/edsnlp/metrics/relations.py new file mode 100644 index 000000000..4bbcdda0b --- /dev/null +++ b/edsnlp/metrics/relations.py @@ -0,0 +1,142 @@ +from collections import defaultdict +from itertools import product +from typing import Any, Optional + +from edsnlp import registry +from edsnlp.metrics import Examples, make_examples, prf +from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans +from edsnlp.utils.typing import AsList + + +def relations_scorer( + examples: Examples, + candidate_getter: AsList[RelationCandidateGetter], + micro_key: str = "micro", + filter_expr: Optional[str] = None, +): + """ + Scores the attributes predictions between a list of gold and predicted spans. + + Parameters + ---------- + examples : Examples + The examples to score, either a tuple of (golds, preds) or a list of + spacy.training.Example objects + candidate_getter : AsList[RelationCandidateGetter] + The candidate getters to use to extract the possible relations from the + documents. Each candidate getter should be a dictionary with the keys + "head", "tail", and "labels". The "head" and "tail" keys should be + SpanGetterArg objects, and the "labels" key should be a list of strings + for these head-tail pairs. + micro_key : str + The key to use to store the micro-averaged results for spans of all types + filter_expr : Optional[str] + The filter expression to use to filter the documents + + Returns + ------- + Dict[str, float] + """ + examples = make_examples(examples) + if filter_expr is not None: + filter_fn = eval(f"lambda doc: {filter_expr}") + examples = [eg for eg in examples if filter_fn(eg.reference)] + # annotations: {label -> preds, golds, pred_with_probs} + annotations = defaultdict(lambda: (set(), set(), dict())) + annotations[micro_key] = (set(), set(), dict()) + total_pred_count = 0 + total_gold_count = 0 + + for candidate in candidate_getter: + head_getter = candidate["head"] + tail_getter = candidate["tail"] + labels = candidate["labels"] + symmetric = candidate.get("symmetric") or False + label_filter = candidate.get("label_filter") + for eg_idx, eg in enumerate(examples): + pred_heads = [ + ((h.start, h.end, h.label_), h) + for h in get_spans(eg.predicted, head_getter) + ] + pred_tails = [ + ((t.start, t.end, t.label_), t) + for t in get_spans(eg.predicted, tail_getter) + ] + for (h_key, head), (t_key, tail) in product(pred_heads, pred_tails): + if ( + label_filter is not None + and head.label_ not in label_filter + or tail.label_ not in label_filter + ): + continue + total_pred_count += 1 + for label in labels: + if ( + tail in head._.rel.get(label, ()) + or symmetric + and head in tail._.rel.get(label, ()) + ): + if symmetric and h_key > t_key: + h_key, t_key = t_key, h_key + annotations[label][0].add((eg_idx, h_key, t_key, label)) + annotations[micro_key][0].add((eg_idx, h_key, t_key, label)) + + gold_heads = [ + ((h.start, h.end, h.label_), h) + for h in get_spans(eg.reference, head_getter) + ] + gold_tails = [ + ((t.start, t.end, t.label_), t) + for t in get_spans(eg.reference, 
tail_getter) + ] + for (h_key, head), (t_key, tail) in product(gold_heads, gold_tails): + total_gold_count += 1 + for label in labels: + if ( + tail in head._.rel.get(label, ()) + or symmetric + and head in tail._.rel.get(label, ()) + ): + if symmetric and h_key > t_key: + h_key, t_key = t_key, h_key + annotations[label][1].add((eg_idx, h_key, t_key, label)) + annotations[micro_key][1].add((eg_idx, h_key, t_key, label)) + + if total_pred_count != total_gold_count: + raise ValueError( + f"Number of predicted and gold candidate pairs differ: {total_pred_count} " + f"!= {total_gold_count}. Make sure that you are running your span " + "attribute classification pipe on the gold annotations, and not spans " + "predicted by another NER pipe in your model." + ) + + return { + name: { + **prf(pred, gold), + # "ap": average_precision(pred_with_prob, gold), + } + for name, (pred, gold, pred_with_prob) in annotations.items() + } + + +@registry.metrics.register("eds.relations") +class RelationsMetric: + def __init__( + self, + candidate_getter: AsList[RelationCandidateGetter], + micro_key: str = "micro", + filter_expr: Optional[str] = None, + ): + self.candidate_getter = candidate_getter + self.micro_key = micro_key + self.filter_expr = filter_expr + + __init__.__doc__ = relations_scorer.__doc__ + + def __call__(self, *examples: Any): + return relations_scorer( + examples, + candidate_getter=self.candidate_getter, + micro_key=self.micro_key, + filter_expr=self.filter_expr, + ) diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index c5055e95f..19ff39d0f 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -75,6 +75,7 @@ from .qualifiers.reported_speech.factory import create_component as reported_speech from .qualifiers.reported_speech.factory import create_component as rspeech from .trainable.ner_crf.factory import create_component as ner_crf + from .trainable.relation_detector_ffn.factory import create_component as relation_detector_ffn from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser from .trainable.extractive_qa.factory import create_component as extractive_qa from .trainable.span_classifier.factory import create_component as span_classifier diff --git a/edsnlp/pipes/base.py b/edsnlp/pipes/base.py index b17f9ba66..37b549e7f 100644 --- a/edsnlp/pipes/base.py +++ b/edsnlp/pipes/base.py @@ -14,6 +14,7 @@ from edsnlp.core import PipelineProtocol from edsnlp.core.registries import DraftPipe from edsnlp.utils.span_getters import ( + RelationCandidateGetter, SpanGetter, # noqa: F401 SpanGetterArg, # noqa: F401 SpanSetter, @@ -23,6 +24,7 @@ validate_span_getter, # noqa: F401 validate_span_setter, ) +from edsnlp.utils.typing import AsList def value_getter(span: Span): @@ -203,3 +205,37 @@ def qualifiers(self): # pragma: no cover @qualifiers.setter def qualifiers(self, value): # pragma: no cover self.attributes = value + + +class BaseRelationDetectorComponent(BaseComponent, abc.ABC): + def __init__( + self, + nlp: PipelineProtocol = None, + name: str = None, + *args, + candidate_getter: AsList[RelationCandidateGetter], + **kwargs, + ): + super().__init__(nlp, name, *args, **kwargs) + self.candidate_getter = [ + { + "head": validate_span_getter(candidate["head"]), + "tail": validate_span_getter(candidate["tail"]), + "labels": candidate["labels"], + "label_filter": { + head: set(tail_labels) + for head, tail_labels in candidate["label_filter"].items() + } + if candidate.get("label_filter") + else None, + "symmetric": 
candidate.get("symmetric") or False, + } + for candidate in candidate_getter + ] + self.labels = sorted( + { + label + for candidate in self.candidate_getter + for label in candidate["labels"] + } + ) diff --git a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py index 5e58cd9bb..f41a4486f 100644 --- a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py +++ b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py @@ -203,7 +203,9 @@ def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput: "embeddings": batch["begins"].with_data(span_embeds), } - embeds = self.embedding(batch["embedding"])["embeddings"] + embeds = self.embedding(batch["embedding"])["embeddings"].refold( + ["context", "word"] + ) _, n_words, dim = embeds.shape device = embeds.device diff --git a/edsnlp/pipes/trainable/relation_detector_ffn/__init__.py b/edsnlp/pipes/trainable/relation_detector_ffn/__init__.py new file mode 100644 index 000000000..549d2fc77 --- /dev/null +++ b/edsnlp/pipes/trainable/relation_detector_ffn/__init__.py @@ -0,0 +1 @@ +from .factory import create_component diff --git a/edsnlp/pipes/trainable/relation_detector_ffn/factory.py b/edsnlp/pipes/trainable/relation_detector_ffn/factory.py new file mode 100644 index 000000000..066f87704 --- /dev/null +++ b/edsnlp/pipes/trainable/relation_detector_ffn/factory.py @@ -0,0 +1,9 @@ +from edsnlp import registry + +from .relation_detector_ffn import RelationDetectorFFN + +create_component = registry.factory.register( + "eds.relation_detector_ffn", + assigns=[], + deprecated=[], +)(RelationDetectorFFN) diff --git a/edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py b/edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py new file mode 100644 index 000000000..3c9fd803f --- /dev/null +++ b/edsnlp/pipes/trainable/relation_detector_ffn/relation_detector_ffn.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import logging +import warnings +from collections import defaultdict +from itertools import product +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Set, +) + +import torch +import torch.nn.functional as F +from spacy.tokens import Doc, Span +from typing_extensions import TypedDict + +from edsnlp.core import PipelineProtocol +from edsnlp.core.torch_component import BatchInput, BatchOutput, TorchComponent +from edsnlp.pipes.base import BaseRelationDetectorComponent +from edsnlp.pipes.trainable.embeddings.typing import ( + SpanEmbeddingComponent, + WordEmbeddingComponent, +) +from edsnlp.utils.span_getters import RelationCandidateGetter, get_spans +from edsnlp.utils.typing import AsList + + +def make_ranges(starts, ends): + """ + Efficient computation and concat of ranges from starts and ends. 
+ + Examples + -------- + ```{ .python .no-check } + + starts = torch.tensor([0, 3, 6]) + ends = torch.tensor([2, 8, 8]) + make_ranges(starts, ends) + # <---> <-----------> <---> + # tensor([0, 1, 3, 4, 5, 6, 7, 6, 7]) + ``` + + Parameters + ---------- + starts: torch.Tensor + ends: torch.Tensor + + Returns + ------- + torch.Tensor + """ + assert starts.shape == ends.shape + if 0 in ends.shape: + return ends + sizes = ends - starts + mask = sizes > 0 + offsets = sizes.cumsum(0) + offsets = offsets.roll(1) + res = torch.ones(offsets[0], dtype=torch.long, device=starts.device) + offsets[0] = 0 + masked_offsets = offsets[mask] + starts = starts[mask] + ends = ends[mask] + res[masked_offsets] = starts + res[masked_offsets[1:]] -= ends[:-1] - 1 + return res.cumsum(0), offsets + + +logger = logging.getLogger(__name__) + +FrameBatchInput = TypedDict( + "FrameBatchInput", + { + "span_embedding": BatchInput, + "word_embedding": BatchInput, + "rel_head_idx": torch.Tensor, + "rel_tail_idx": torch.Tensor, + "rel_doc_idx": torch.Tensor, + "rel_labels": torch.Tensor, + }, +) +""" +span_embedding: torch.FloatTensor + Token embeddings to predict the tags from +""" + + +class MLP(torch.nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, dropout_p: float = 0.0 + ): + super().__init__() + self.hidden = torch.nn.Linear(input_dim, hidden_dim) + self.output = torch.nn.Linear(hidden_dim, output_dim) + self.dropout = torch.nn.Dropout(dropout_p) + + def forward(self, x): + x = self.dropout(x) + x = self.hidden(x) + x = F.gelu(x) + x = self.output(x) + return x + + +class RelationDetectorFFN( + TorchComponent[BatchOutput, FrameBatchInput], + BaseRelationDetectorComponent, +): + def __init__( + self, + nlp: Optional[PipelineProtocol] = None, + name: str = "relation_detector_ffn", + *, + span_embedding: SpanEmbeddingComponent, + word_embedding: WordEmbeddingComponent, + candidate_getter: AsList[RelationCandidateGetter], + hidden_size: int = 128, + dropout_p: float = 0.0, + use_inter_words: bool = True, + ): + super().__init__( + nlp=nlp, + name=name, + candidate_getter=candidate_getter, + ) + self.span_embedding = span_embedding + self.word_embedding = word_embedding + self.use_inter_words = use_inter_words + + embed_size = self.span_embedding.output_size * 2 + ( + self.word_embedding.output_size if use_inter_words else 0 + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + # self.head_projection = torch.nn.Linear(hidden_size, hidden_size) + # self.tail_projection = torch.nn.Linear(hidden_size, hidden_size) + self.mlp = MLP(embed_size, hidden_size, hidden_size, dropout_p) + self.classifier = torch.nn.Linear(hidden_size, len(self.labels)) + + @property + def span_getter(self): + return self.embedding.span_getter + + def to_disk(self, path, *, exclude=set()): + repr_id = object.__repr__(self) + if repr_id in exclude: + return + return super().to_disk(path, exclude=exclude) + + def from_disk(self, path, exclude=tuple()): + repr_id = object.__repr__(self) + if repr_id in exclude: + return + self.set_extensions() + super().from_disk(path, exclude=exclude) + + def set_extensions(self): + super().set_extensions() + if not Span.has_extension("rel"): + Span.set_extension("rel", default={}) + if not Span.has_extension("scope"): + Span.set_extension("scope", default=None) + + def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]): + super().post_init(gold_data, exclude=exclude) + + def preprocess(self, doc: Doc, supervised: int = False) -> Dict[str, 
Any]: + rel_head_idx = [] + rel_tail_idx = [] + rel_labels = [] + rel_getter_indices = [] + + all_spans = defaultdict(lambda: len(all_spans)) + + for getter_idx, getter in enumerate(self.candidate_getter): + head_spans = list(get_spans(doc, getter["head"])) + tail_spans = list(get_spans(doc, getter["tail"])) + lab_filter = getter.get("label_filter") + assert lab_filter is not None + for head, tail in product(head_spans, tail_spans): + if lab_filter and head in lab_filter and tail not in lab_filter[head]: + continue + rel_head_idx.append(all_spans[head]) + rel_tail_idx.append(all_spans[tail]) + rel_getter_indices.append(getter_idx) + if supervised: + rel_labels.append( + [ + ( + tail in head._.rel.get(lab, ()) + or ( + getter["symmetric"] + and head in tail._.rel.get(lab, ()) + ) + ) + for lab in self.labels + ] + ) + + result = { + "num_spans": len(all_spans), + "rel_heads": rel_head_idx, + "rel_tails": rel_tail_idx, + "word_embedding": self.word_embedding.preprocess(doc, contexts=None), + "span_embedding": self.span_embedding.preprocess( + doc, + spans=list(all_spans), + contexts=None, + ), + "$spans": list(all_spans.keys()), + "$getter": rel_getter_indices, + "stats": { + "relation_candidates": len(rel_head_idx), + }, + } + if supervised: + result["rel_labels"] = rel_labels + + return result + + def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]: + return self.preprocess(doc, supervised=True) + + def collate(self, batch: Dict[str, Any]) -> FrameBatchInput: + rel_heads = [] + rel_tails = [] + rel_doc_idx = [] + offset = 0 + for doc_idx, feats in enumerate( + zip( + batch["rel_heads"], + batch["rel_tails"], + batch["num_spans"], + ) + ): + doc_rel_heads, doc_rel_tails, doc_num_spans = feats + rel_heads.extend([x + offset for x in doc_rel_heads]) + rel_tails.extend([x + offset for x in doc_rel_tails]) + rel_doc_idx.extend([doc_idx] * len(doc_rel_heads)) + offset += batch["num_spans"][doc_idx] + + collated: FrameBatchInput = { + "rel_head_idx": torch.as_tensor(rel_heads, dtype=torch.long), + "rel_tail_idx": torch.as_tensor(rel_tails, dtype=torch.long), + "rel_doc_idx": torch.as_tensor(rel_doc_idx, dtype=torch.long), + "span_embedding": self.span_embedding.collate(batch["span_embedding"]), + "word_embedding": self.word_embedding.collate(batch["word_embedding"]), + "stats": {"relation_candidates": len(rel_heads)}, + } + + if "rel_labels" in batch: + collated["rel_labels"] = torch.as_tensor( + [labs for doc_labels in batch["rel_labels"] for labs in doc_labels] + ).view(-1, self.classifier.out_features) + return collated + + def compute_inter_span_embeds(self, word_embeds, begins, ends, head_idx, tail_idx): + _, n_words, dim = word_embeds.shape + if 0 in begins.shape or 0 in head_idx.shape: + return torch.zeros( + 0, dim, dtype=word_embeds.dtype, device=word_embeds.device + ) + + flat_begins = torch.minimum(ends[head_idx], ends[tail_idx]) + flat_ends = torch.maximum(begins[head_idx], begins[tail_idx]) + flat_begins, flat_ends = ( + torch.minimum(flat_begins, flat_ends), + torch.maximum(flat_begins, flat_ends), + ) + flat_embeds = word_embeds.view(-1, dim) + flat_indices, flat_offsets = make_ranges(flat_begins, flat_ends) + flat_offsets[0] = 0 + inter_span_embeds = torch.nn.functional.embedding_bag( # type: ignore + input=flat_indices, + weight=flat_embeds, + offsets=flat_offsets, + mode="mean", + ) + return inter_span_embeds + + # noinspection SpellCheckingInspection + def forward(self, batch: FrameBatchInput) -> BatchOutput: + """ + Apply the span classifier module to the document 
embeddings and given spans to: + - compute the loss + - and/or predict the labels of spans + + Parameters + ---------- + batch: SpanQualifierBatchInput + The input batch + + Returns + ------- + BatchOutput + """ + word_embeds = self.word_embedding(batch["word_embedding"])["embeddings"] + span_embeds = self.span_embedding(batch["span_embedding"])["embeddings"] + + n_words = word_embeds.size(-2) + spans = batch["span_embedding"] + flat_begins = n_words * spans["sequence_idx"] + spans["begins"].as_tensor() + flat_ends = n_words * spans["sequence_idx"] + spans["ends"].as_tensor() + if self.use_inter_words: + inter_span_embeds = self.compute_inter_span_embeds( + word_embeds=word_embeds, + begins=flat_begins, + ends=flat_ends, + head_idx=batch["rel_head_idx"], + tail_idx=batch["rel_tail_idx"], + ) + rel_embeds = torch.cat( + [ + span_embeds[batch["rel_head_idx"]], + inter_span_embeds, + span_embeds[batch["rel_tail_idx"]], + ], + dim=-1, + ) + else: + rel_embeds = torch.cat( + [ + span_embeds[batch["rel_head_idx"]], + span_embeds[batch["rel_tail_idx"]], + ], + dim=-1, + ) + rel_embeds = self.mlp(rel_embeds) + logits = self.classifier(rel_embeds) + + losses = pred = None + if "rel_labels" in batch: + losses = [] + target = batch["rel_labels"].float() + num_relation_candidates = batch["stats"]["relation_candidates"] + losses.append( + F.binary_cross_entropy_with_logits(logits, target, reduction="sum") + / num_relation_candidates + ) + else: + pred = logits > 0 + + return { + "loss": sum(losses) if losses is not None else None, + "pred": pred, + } + + def postprocess( + self, + docs: List[Doc], + results: BatchOutput, + inputs: List[Dict[str, Any]], + ): + """ + Extract predicted relations from forward's "pred" field (boolean tensor) + and annotated them on the head._.rel attribute (dictionary) + Parameters + ---------- + docs: Sequence[Doc] + List of documents to update + results: BatchOutput + Batch of predictions, as returned by the forward method + inputs: BatchInput + List of preprocessed features, as returned by the preprocess method + + Returns + ------- + """ + all_heads = [p["$spans"][idx] for p in inputs for idx in p["rel_heads"]] + all_tails = [p["$spans"][idx] for p in inputs for idx in p["rel_tails"]] + getter_indices = [idx for p in inputs for idx in p["$getter"]] + for pair_idx, label_idx in results["pred"].nonzero(as_tuple=False).tolist(): + head = all_heads[pair_idx] + tail = all_tails[pair_idx] + label = self.labels[label_idx] + head._.rel.setdefault(label, set()).add(tail) + if self.candidate_getter[getter_indices[pair_idx]]["symmetric"]: + tail._.rel.setdefault(label, set()).add(head) + return docs diff --git a/edsnlp/training/trainer.py b/edsnlp/training/trainer.py index 05bc935b9..5fe2a5f67 100644 --- a/edsnlp/training/trainer.py +++ b/edsnlp/training/trainer.py @@ -36,6 +36,7 @@ from edsnlp.metrics.span_attributes import SpanAttributeMetric from edsnlp.pipes.base import ( BaseNERComponent, + BaseRelationDetectorComponent, BaseSpanAttributeClassifierComponent, ) from edsnlp.utils.batching import BatchSizeArg, stat_batchify @@ -44,6 +45,7 @@ from edsnlp.utils.span_getters import get_spans from edsnlp.utils.typing import AsList +from ..metrics.relations import RelationsMetric from .optimizer import LinearSchedule, ScheduledOptimizer LOGGER_FIELDS = { @@ -189,6 +191,35 @@ def __call__(self, nlp: Pipeline, docs: Iterable[Any]): for name, scorer in span_attr_scorers.items(): scores[name] = scorer(docs, qlf_preds) + # Relations + rel_pipes = [ + name + for name, pipe in nlp.pipeline 
+ if isinstance(pipe, BaseRelationDetectorComponent) + ] + rel_scorers = { + name: scorers.pop(name) + for name in list(scorers) + if isinstance(scorers[name], RelationsMetric) + } + if rel_pipes and rel_scorers: + clean_rel_docs = [d.copy() for d in tqdm(docs, desc="Copying docs")] + for doc in clean_rel_docs: + for name in rel_pipes: + pipe = nlp.get_pipe(name) + for candidate_getter in pipe.candidate_getter: + for span in ( + *get_spans(doc, candidate_getter["head"]), + *get_spans(doc, candidate_getter["tail"]), + ): + for label in nlp.get_pipe(name).labels: + if label in span._.rel: + span._.rel[label].clear() + with nlp.select_pipes(disable=ner_pipes): + rel_preds = list(nlp.pipe(tqdm(clean_rel_docs, desc="Predicting"))) + for name, scorer in rel_scorers.items(): + scores[name] = scorer(docs, rel_preds) + # Custom scorers for name, scorer in scorers.items(): pred_docs = [d.copy() for d in tqdm(docs, desc="Copying docs")] diff --git a/edsnlp/utils/span_getters.py b/edsnlp/utils/span_getters.py index b1b8a156b..25297479e 100644 --- a/edsnlp/utils/span_getters.py +++ b/edsnlp/utils/span_getters.py @@ -8,12 +8,14 @@ List, Optional, Sequence, + Set, Tuple, Union, ) from pydantic import NonNegativeInt from spacy.tokens import Doc, Span +from typing_extensions import NotRequired, TypedDict from edsnlp import registry from edsnlp.utils.filter import filter_spans @@ -45,8 +47,12 @@ def get_spans(doc, span_getter): for key, span_filter in span_getter.items(): if key == "*": candidates = (span for group in doc.spans.values() for span in group) + elif key == "ents": + candidates = doc.ents + elif key == "doc": + candidates = (doc[:],) else: - candidates = doc.spans.get(key, ()) if key != "ents" else doc.ents + candidates = doc.spans.get(key, ()) if span_filter is True: yield from candidates else: @@ -67,8 +73,12 @@ def get_spans_with_group(doc, span_getter): candidates = ( (span, group) for group in doc.spans.values() for span in group ) + elif key == "ents": + candidates = ((span, key) for span in doc.ents) + elif key == "doc": + candidates = ((doc[:], "doc"),) else: - candidates = doc.spans.get(key, ()) if key != "ents" else doc.ents + candidates = doc.spans.get(key, ()) candidates = ((span, key) for span in candidates) if span_filter is True: yield from candidates @@ -321,3 +331,15 @@ def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]: end = max(end, max_end_sent) return span.doc[start:end] + + +RelationCandidateGetter = TypedDict( + "RelationCandidateGetter", + { + "head": SpanGetterArg, + "tail": SpanGetterArg, + "labels": AsList[str], + "label_filter": NotRequired[Dict[str, Set[str]]], + "symmetric": Optional[bool], + }, +) diff --git a/pyproject.toml b/pyproject.toml index 59577e3b3..8df5a4ede 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,16 +248,17 @@ where = ["."] # edsnlp will look both in the above dict and in the one below. 
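# (keys below are the names used in configs, e.g. '@factory': "eds.covid";
# values are the import paths of the registered functions)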
[project.entry-points."edsnlp_factories"] # Trainable -"eds.transformer" = "edsnlp.pipes.trainable.embeddings.transformer.factory:create_component" -"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" -"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" -"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" -"eds.extractive_qa" = "edsnlp.pipes.trainable.extractive_qa.factory:create_component" -"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" -"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" -"eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" -"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" -"eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" +"eds.transformer" = "edsnlp.pipes.trainable.embeddings.transformer.factory:create_component" +"eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" +"eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" +"eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.extractive_qa" = "edsnlp.pipes.trainable.extractive_qa.factory:create_component" +"eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" +"eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" +"eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" +"eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" +"eds.relation_detector_ffn" = "edsnlp.pipes.trainable.relation_detector_ffn.factory:create_component" [project.entry-points."spacy_scorers"] "eds.ner_exact" = "edsnlp.metrics.ner:NerExactMetric" @@ -265,6 +266,7 @@ where = ["."] "eds.ner_overlap" = "edsnlp.metrics.ner:NerOverlapMetric" "eds.span_attributes" = "edsnlp.metrics.span_attributes:SpanAttributeMetric" "eds.dep_parsing" = "edsnlp.metrics.dep_parsing:DependencyParsingMetric" +"eds.relations" = "edsnlp.metrics.relations:RelationsMetric" # Deprecated "eds.ner_exact_metric" = "edsnlp.metrics.ner:NerExactMetric" diff --git a/tests/data/test_converters.py b/tests/data/test_converters.py index c53ca1ef7..9e5d42396 100644 --- a/tests/data/test_converters.py +++ b/tests/data/test_converters.py @@ -83,6 +83,14 @@ def test_read_standoff_dict(blank_nlp): "label": "test", }, ], + "relations": [ + { + "relation_id": "R1", + "relation_label": "linked", + "from_entity_id": 1, + "to_entity_id": 0, + } + ], } doc = get_dict2doc_converter( "standoff", @@ -98,6 +106,7 @@ def test_read_standoff_dict(blank_nlp): assert doc.ents[0].text == "This" assert doc.ents[0]._.negation is True assert doc.ents[1]._.negation is False + assert doc.ents[1]._.rel["linked"] == {doc.ents[0]} def test_write_omop_dict(blank_nlp): diff --git a/tests/training/dataset_2/sample-1.ann b/tests/training/dataset_2/sample-1.ann new file mode 100644 index 000000000..d5ee1745f --- /dev/null +++ b/tests/training/dataset_2/sample-1.ann @@ -0,0 +1,6 @@ +#1000 Status #1000 CHECKED +T1 date 6 18 19 juin 1987 +T2 covid 52 57 covid +T3 date 69 84 12 octobre 1983 +T3 covid 103 108 covid +R1 linked Arg1:T2 Arg2:T1 diff --git a/tests/training/dataset_2/sample-1.txt b/tests/training/dataset_2/sample-1.txt new file 
mode 100644 index 000000000..e332852e8 --- /dev/null +++ b/tests/training/dataset_2/sample-1.txt @@ -0,0 +1,2 @@ +CR du 19 juin 1987. La patiente a été diagnostiquée covid positif le 12 octobre 1983. +Autre occurrence covid mentionnée sans date précise. diff --git a/tests/training/dataset_2/sample-2.ann b/tests/training/dataset_2/sample-2.ann new file mode 100644 index 000000000..bd4ab3c7e --- /dev/null +++ b/tests/training/dataset_2/sample-2.ann @@ -0,0 +1,4 @@ +T1 date 11 24 29 avril 2020 +T2 date 62 70 30 avril +T3 covid 101 106 covid +R1 linked Arg1:T3 Arg2:T1 diff --git a/tests/training/dataset_2/sample-2.txt b/tests/training/dataset_2/sample-2.txt new file mode 100644 index 000000000..9d0c579b5 --- /dev/null +++ b/tests/training/dataset_2/sample-2.txt @@ -0,0 +1 @@ +On est le 29 avril 2020, et j'ai rendez vous à l'aéroport le 30 avril et je n'ai toujours pas eu le covid, je croise les doigts ! diff --git a/tests/training/rel_config.cfg b/tests/training/rel_config.cfg new file mode 100644 index 000000000..7a2dc2ea7 --- /dev/null +++ b/tests/training/rel_config.cfg @@ -0,0 +1,79 @@ +[nlp] +lang = "eds" +pipeline = [ + "normalizer", + "sentencizer", + "covid", + "dates", + "relations", + ] +batch_size = 2 +components = ${components} +tokenizer = {"@tokenizers": "eds.tokenizer"} + +[components.normalizer] +@factory = "eds.normalizer" + +[components.sentencizer] +@factory = "eds.sentences" + +[components.covid] +@factory = "eds.covid" + +[components.dates] +@factory = "eds.dates" + +# Relations component is: +# - a span relation detector, that classifies pairs spans embedded by... +# - a span pooler, that pools words embedded by... +# - a text cnn, that re-contextualizes words embedded by... +# - a transformer +[components.relations] +@factory = "eds.relation_detector_ffn" +head_getter = {"ents": "covid"} +tail_getter = {"ents": "date"} +labels = ["linked"] +symmetric = true + +[components.relations.word_embedding] +@factory = "eds.text_cnn" +kernel_sizes = [3] + +[components.relations.word_embedding.embedding] +@factory = "eds.transformer" +model = "hf-internal-testing/tiny-bert" +window = 128 +stride = 96 + +[components.relations.span_embedding] +@factory = "eds.span_pooler" +embedding = ${components.relations.word_embedding} + +[scorer.rel] +@metrics = "eds.relations" +head_getter = ${components.relations.head_getter} +tail_getter = ${components.relations.tail_getter} +labels = ${components.relations.labels} + +[train] +nlp = ${nlp} +max_steps = 50 +validation_interval = ${train.max_steps//10} +warmup_rate = 0 +batch_size = 2 samples +transformer_lr = 0 +task_lr = 1e-3 +scorer = ${scorer} + +[train.train_data] +randomize = true +max_length = 100 +multi_sentence = true +[train.train_data.reader] +@readers = "standoff" +path = "./dataset_2/" + +[train.val_data] +[train.val_data.reader] +@readers = "standoff" +path = "./dataset_2/" diff --git a/tests/training/rel_config.yml b/tests/training/rel_config.yml new file mode 100644 index 000000000..1bdc3fd01 --- /dev/null +++ b/tests/training/rel_config.yml @@ -0,0 +1,90 @@ +# 🤖 PIPELINE DEFINITION +nlp: + "@core": pipeline + + lang: eds + + components: + normalizer: + '@factory': eds.normalizer + + sentencizer: + '@factory': eds.sentences + + covid: + '@factory': eds.covid + + relations: + '@factory': "eds.relation_detector_ffn" + candidate_getter: + head: { "ents": "covid" } + tail: { "ents": "date" } + labels: [ "linked" ] + symmetric: true + + word_embedding: + '@factory': eds.text_cnn + kernel_sizes: [ 3 ] + + embedding: + '@factory': 
diff --git a/tests/training/rel_config.yml b/tests/training/rel_config.yml
new file mode 100644
index 000000000..1bdc3fd01
--- /dev/null
+++ b/tests/training/rel_config.yml
@@ -0,0 +1,93 @@
+# 🤖 PIPELINE DEFINITION
+nlp:
+  "@core": pipeline
+
+  lang: eds
+
+  components:
+    normalizer:
+      '@factory': eds.normalizer
+
+    sentencizer:
+      '@factory': eds.sentences
+
+    covid:
+      '@factory': eds.covid
+
+    dates:
+      '@factory': eds.dates
+
+    relations:
+      '@factory': "eds.relation_detector_ffn"
+      candidate_getter:
+        head: { "ents": "covid" }
+        tail: { "ents": "date" }
+        labels: [ "linked" ]
+        symmetric: true
+
+      word_embedding:
+        '@factory': eds.text_cnn
+        kernel_sizes: [ 3 ]
+
+        embedding:
+          '@factory': eds.transformer
+          model: hf-internal-testing/tiny-bert
+          window: 128
+          stride: 96
+
+      span_embedding:
+        '@factory': eds.span_pooler
+        embedding: ${nlp.components.relations.word_embedding}
+
+# 📈 SCORERS
+scorer:
+  speed: true
+  batch_size: 2 docs
+  rel:
+    "@metrics": eds.relations
+    candidate_getter: ${nlp.components.relations.candidate_getter}
+
+# 🎛️ OPTIMIZER
+# (disabled to test the default optimizer)
+# optimizer:
+#   "@optimizers": adam
+#   groups:
+#     "*.transformer.*":
+#       lr: 1e-3
+#       schedules:
+#         "@schedules": linear
+#         "warmup_rate": 0.1
+#         "start_value": 0
+#     "*":
+#       lr: 1e-3
+#       schedules:
+#         "@schedules": linear
+#         "warmup_rate": 0.1
+#         "start_value": 1e-3
+
+# 📚 DATA
+train_data:
+  - data:
+      '@readers': standoff
+      path: ./dataset_2/
+      converter:
+        - '@factory': eds.standoff_dict2doc
+    shuffle: dataset
+    batch_size: 1 docs
+
+val_data:
+  '@readers': standoff
+  path: ./dataset_2/
+  converter:
+    - '@factory': eds.standoff_dict2doc
+
+# 🚀 TRAIN SCRIPT OPTIONS
+train:
+  nlp: ${ nlp }
+  train_data: ${ train_data }
+  val_data: ${ val_data }
+  max_steps: 40
+  validation_interval: 10
+  grad_max_norm: 1.0
+  scorer: ${ scorer }
+  num_workers: 1
diff --git a/tests/training/test_train.py b/tests/training/test_train.py
index 8a1e4f03b..53c1617b4 100644
--- a/tests/training/test_train.py
+++ b/tests/training/test_train.py
@@ -208,6 +208,22 @@ def test_dep_parser_train(run_in_test_dir, tmp_path):
     assert last_scores["dep"]["las"] >= 0.4
 
 
+def test_rel_train(run_in_test_dir, tmp_path):
+    set_seed(42)
+    config = Config.from_disk("rel_config.yml")
+    shutil.rmtree(tmp_path, ignore_errors=True)
+    kwargs = Config.resolve(config["train"], registry=registry, root=config)
+    nlp = train(**kwargs, output_dir=tmp_path, cpu=True)
+    scorer = GenericScorer(**kwargs["scorer"])
+    val_data = kwargs["val_data"]
+    last_scores = scorer(nlp, val_data)
+
+    # Check that the pipeline can process an empty doc
+    nlp("")
+
+    assert last_scores["rel"]["micro"]["f"] >= 0.4
+
+
 def test_optimizer():
     net = torch.nn.Linear(10, 10)
     optim = ScheduledOptimizer(
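The metric registered as `eds.relations` can also be used standalone. A hedged sketch, assuming it follows the call convention of the other edsnlp metrics (gold docs, then predicted docs); scoring the gold annotations against themselves should yield a perfect score:

```python
import edsnlp
from edsnlp.metrics.relations import RelationsMetric

# The constructor arguments mirror the [scorer.rel] section of
# rel_config.cfg; the call convention is an assumption.
docs = list(edsnlp.data.read_standoff("./dataset_2/"))
metric = RelationsMetric(
    head_getter={"ents": "covid"},
    tail_getter={"ents": "date"},
    labels=["linked"],
)
scores = metric(docs, docs)
print(scores["micro"]["f"])  # same structure as asserted in test_rel_train
```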
From 795852c8bdca3f9622d9ce1e96b4b23ac9df9e9c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Tue, 17 Dec 2024 12:10:09 +0100
Subject: [PATCH 6/6] fix: dedup contexts in eds.span_pooler

---
 .../embeddings/span_pooler/span_pooler.py          | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
index f41a4486f..0cbb69030 100644
--- a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
+++ b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
@@ -123,15 +123,18 @@ def preprocess(
         begins = []
         ends = []
-        contexts_to_idx = {span: i for i, span in enumerate(contexts)}
+        contexts_to_idx = {}
+        for ctx in contexts:
+            contexts_to_idx.setdefault(ctx, len(contexts_to_idx))
+        dedup_contexts = sorted(contexts_to_idx, key=contexts_to_idx.get)
         assert not pre_aligned or len(spans) == len(contexts), (
             "When `pre_aligned` is True, the number of spans and contexts must be the "
             "same."
         )
 
         aligned_contexts = (
-            [[c] for c in contexts]
+            [[c] for c in dedup_contexts]
             if pre_aligned
-            else align_spans(contexts, spans, sort_by_overlap=True)
+            else align_spans(dedup_contexts, spans, sort_by_overlap=True)
         )
         for i, (span, ctx) in enumerate(zip(spans, aligned_contexts)):
             if len(ctx) == 0 or ctx[0].start > span.start or ctx[0].end < span.end:
@@ -143,12 +146,16 @@ def preprocess(
             sequence_idx.append(contexts_to_idx[ctx[0]])
             begins.append(span.start - start)
             ends.append(span.end - start)
+            assert begins[-1] >= 0, f"Begin offset is negative: {span.text}"
+            assert ends[-1] <= len(ctx[0]), f"End offset is out of bounds: {span.text}"
 
         return {
             "begins": begins,
             "ends": ends,
             "sequence_idx": sequence_idx,
-            "num_sequences": len(contexts),
-            "embedding": self.embedding.preprocess(doc, contexts=contexts, **kwargs),
+            "num_sequences": len(dedup_contexts),
+            "embedding": self.embedding.preprocess(
+                doc, contexts=dedup_contexts, **kwargs
+            ),
             "stats": {"spans": len(begins)},
         }
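At the heart of this patch is first-occurrence indexing to dedup contexts. The pattern in isolation, as plain Python:

```python
# First-occurrence dedup, as in preprocess() above: every context keeps the
# index of its first appearance, so duplicated contexts are embedded only
# once while sequence_idx still points each span at the right unique context.
contexts = ["ctx_a", "ctx_b", "ctx_a", "ctx_c"]
contexts_to_idx = {}
for ctx in contexts:
    contexts_to_idx.setdefault(ctx, len(contexts_to_idx))
dedup_contexts = sorted(contexts_to_idx, key=contexts_to_idx.get)

assert dedup_contexts == ["ctx_a", "ctx_b", "ctx_c"]
assert [contexts_to_idx[c] for c in contexts] == [0, 1, 0, 2]
```

Note that `setdefault` (rather than plain assignment) is what keeps the mapping stable when a context repeats: with assignment, a duplicate would be re-indexed past the end of the deduplicated list.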