diff --git a/adapter_bert.py b/adapter_bert.py
index 78ef956..b2372a8 100644
--- a/adapter_bert.py
+++ b/adapter_bert.py
@@ -19,9 +19,9 @@ def __init__(self, BertOutput, config):
             self.adapter = HoulsbyAdapter(config.hidden_size)
         elif config.adapter == "conv_adapter":
             self.adapter = ConvAdapter(config.max_position_embeddings)
-        elif self.adapter == "AdapterBias":
+        elif config.adapter == "AdapterBias":
             self.adapter = AdapterBias(config.hidden_size)
-        elif self.adapter == "lora":
+        elif config.adapter == "lora":
             self.adapter = LoRA(config.hidden_size)
         else:
             raise NotImplementedError
diff --git a/adapters.py b/adapters.py
index 842f235..95b986f 100644
--- a/adapters.py
+++ b/adapters.py
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import loralib as lora

 class HoulsbyAdapter(nn.Module):
@@ -73,4 +74,4 @@ def __init__(
         self.lora_adapter = lora.Linear(input_size, input_size, r)

     def forward(self, x):
-        return self.lora_adapter(x)
\ No newline at end of file
+        return self.lora_adapter(x)
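For context: the first hunk fixes a dispatch bug in `adapter_bert.py`. `self.adapter` is only assigned inside the branches, so comparing it against a string in the `elif` conditions raised an `AttributeError` whenever `config.adapter` was `"AdapterBias"` or `"lora"`; testing `config.adapter` matches the preceding branches. The remaining hunks add the missing `loralib` import that `lora.Linear` relies on and terminate the file with a newline. Below is a minimal sketch of how the `LoRA` class in `adapters.py` fits together after the patch, assuming the `loralib` package from microsoft/LoRA; the `r=8` default shown here is an illustrative assumption, not taken from the diff:

```python
import torch
from torch import nn
import loralib as lora  # pip install loralib


class LoRA(nn.Module):
    """Adapter that routes activations through a low-rank loralib layer."""

    def __init__(self, input_size, r=8):  # r=8 is an assumed default for illustration
        super().__init__()
        # lora.Linear subclasses nn.Linear and adds a trainable rank-r update
        # (lora_B @ lora_A); once non-LoRA weights are frozen, only
        # 2 * input_size * r parameters are trained.
        self.lora_adapter = lora.Linear(input_size, input_size, r)

    def forward(self, x):
        return self.lora_adapter(x)
```

Note that `loralib` expects the base weights to be frozen after the model is built, e.g. via `lora.mark_only_lora_as_trainable(model)`, so that only the rank-r factors receive gradients.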