feat: Allow LoRA to be merged into weights
#12 opened by Markus28

Files changed: modeling_lora.py (+16 -0)
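The mechanism behind the new `LoRAParametrization.merge_lora_into_layer` classmethod is PyTorch's parametrization API: `torch.nn.utils.parametrize.remove_parametrizations(..., leave_parametrized=True)` evaluates the parametrization one last time and stores the result as a plain tensor, so the LoRA delta ends up baked into the layer's weight. Below is a minimal, self-contained sketch of that behaviour; the `ToyLoRA` module, its shapes, and its scaling are illustrative stand-ins for the repo's `LoRAParametrization`, not the actual implementation.

```python
# Minimal sketch (assumption: ToyLoRA is a hypothetical stand-in for LoRAParametrization).
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class ToyLoRA(nn.Module):
    """Parametrizes a weight as W + (B @ A) * scaling, LoRA-style."""

    def __init__(self, fan_in: int, fan_out: int, rank: int = 4, scaling: float = 1.0):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, fan_in) * 0.01)
        self.lora_B = nn.Parameter(torch.randn(fan_out, rank) * 0.01)
        self.scaling = scaling

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight + (self.lora_B @ self.lora_A) * self.scaling


layer = nn.Linear(16, 8)
parametrize.register_parametrization(layer, "weight", ToyLoRA(fan_in=16, fan_out=8))

# Weight as seen through the parametrization (base weight + LoRA delta).
merged_view = layer.weight.detach().clone()

# The same call this PR issues per parametrized attribute: apply the
# parametrization once more and keep the result as a plain parameter.
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)

assert not parametrize.is_parametrized(layer)
assert torch.allclose(layer.weight, merged_view)  # LoRA delta is now part of the weight
```

Because the parametrization modules are removed, a merged model no longer carries the per-task LoRA matrices, which is why the diff also guards against merging twice or selecting a new task afterwards.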
modeling_lora.py CHANGED

@@ -199,6 +199,12 @@ class LoRAParametrization(nn.Module):
         if isinstance(layer, LoRAParametrization):
             layer.current_task = task_idx
 
+    @classmethod
+    def merge_lora_into_layer(cls, layer: nn.Module):
+        if hasattr(layer, "parametrizations"):
+            for attr_name in layer.parametrizations.keys():
+                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
+
 
 class BertLoRA(BertPreTrainedModel):
     def __init__(self, config: JinaBertConfig, bert: Optional[BertModel] = None, add_pooling_layer=True):
@@ -207,6 +213,7 @@ class BertLoRA(BertPreTrainedModel):
             self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
         else:
             self.bert = bert
+        self._is_merged = False
         self._num_adaptions = config.num_loras
         self._register_lora(self._num_adaptions)
         self.main_params_trainable = False
@@ -230,6 +237,13 @@ class BertLoRA(BertPreTrainedModel):
         config = JinaBertConfig.from_pretrained(*args, **kwargs)
         return cls(config, bert=bert, num_adaptions=num_adaptions)
 
+    def merge_lora(self):
+        """Merges currently selected LoRA into main weights."""
+        if self._is_merged:
+            raise Exception('LoRA has already been merged, cannot merge again')
+        self._is_merged = True
+        self.apply(LoRAParametrization.merge_lora_into_layer)
+
     @classmethod
     def from_pretrained(
         cls,
@@ -265,6 +279,8 @@ class BertLoRA(BertPreTrainedModel):
 
     @current_task.setter
     def current_task(self, task_idx: Union[None, int]):
+        if self._is_merged:
+            raise Exception('LoRA has been merged, cannot select new task')
         assert task_idx is None or 0 <= task_idx < self._num_adaptions
         if self._task_idx != task_idx:
             self._task_idx = task_idx
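For reviewers, a hedged usage sketch of the API this PR adds. The checkpoint path and task index are placeholders, and the `from_pretrained` call is assumed to follow the repo's existing signature; only `current_task` and `merge_lora` come from this diff.

```python
from modeling_lora import BertLoRA

# Hypothetical checkpoint path; any JinaBertConfig-compatible checkpoint is meant here.
model = BertLoRA.from_pretrained("path/to/jina-bert-lora-checkpoint")

model.current_task = 0   # pick the adapter to fold into the main weights
model.merge_lora()       # applies LoRAParametrization.merge_lora_into_layer to every module

# The guards added in this PR make a second merge or a later task switch fail loudly.
try:
    model.current_task = 1
except Exception as err:
    print(err)  # "LoRA has been merged, cannot select new task"
```

One design note: merging is one-way by construction, since `remove_parametrizations` discards the parametrization modules that hold the per-task LoRA weights; the `_is_merged` flag turns what would otherwise be a silent no-op into an explicit error.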