Antuke committed on
Commit
ac8620f
·
1 Parent(s): 3a047be

adding missing dependency dlora

Browse files
Files changed (1)
  1. src/dlora.py +668 -0
src/dlora.py ADDED
@@ -0,0 +1,668 @@
+ # --------------------------------------------------------
+ # MTLoRA
+ # GitHub: https://github.com/scale-lab/MTLoRA
+ # Built upon Microsoft LoRA (https://github.com/microsoft/LoRA)
+ #
+ # Original file:
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License
+ #
+ # Adapted file:
+ # Copyright (c) 2024 SCALE Lab, Brown University
+ # Licensed under the MIT License (see LICENSE for details)
+ # --------------------------------------------------------
+
+ r"""
+ Low-Rank Adaptation (LoRA) scheme for LLMs.
+
+      ┌───────────────────┐
+      ┆         h         ┆
+      └───────────────────┘
+                ▲
+                |
+                +
+             /     \
+     ┌─────────────────┐    ╭───────────────╮     Matrix initialization:
+     ┆                 ┆     \      B      /      B = 0
+     ┆   pretrained    ┆      \    r*d    /       A = N(0, sigma^2)
+     ┆    weights      ┆       ╰─────────╯
+     ┆                 ┆       |    r    |        r - rank
+     ┆   W e R^(d*d)   ┆       | ◀─────▶ |
+     ┆                 ┆       ╭─────────╮
+     └─────────────────┘      /     A     \
+               ▲             /     d*r     \
+                \           ╰───────────────╯
+                 \                ▲
+                  \              /
+                   \            /
+      ┌───────────────────┐
+      ┆         x         ┆
+      └───────────────────┘
+
+ With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning weights of size d*d,
+ we can freeze the pretrained weights and learn two matrices of size d*r and r*d (they store the weight updates
+ for the pretrained weights): the number of parameters is reduced drastically (depending on the rank, of course),
+ yet after multiplying the d*r and r*d matrices we get a d*d matrix which we can sum with the frozen
+ pretrained weights and thus fine-tune the model.
+
+ The goal of this approach is to move the weight updates into a separate matrix which is decomposed into
+ two matrices of a lower rank.
+ """
+
+ import math
+ from typing import Any, Dict, Tuple, Union, Mapping
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ class LoRALayer(nn.Module):
+     def __init__(self, r: int, lora_alpha: int, lora_dropout: float):
+         """Store LoRA specific attributes in a class.
+
+         Args:
+             r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
+                 the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
+             lora_alpha: alpha is needed for scaling updates as alpha/r
+                 "This scaling helps to reduce the need to retune hyperparameters when we vary r"
+                 https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
+             lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
+         """
+         super().__init__()
+         assert r >= 0
+         self.r = r
+         self.lora_alpha = lora_alpha
+         # Optional dropout
+         if lora_dropout > 0.0:
+             self.lora_dropout = nn.Dropout(p=lora_dropout)
+         else:
+             self.lora_dropout = lambda x: x
+         # Mark the weight as unmerged
+         self.merged = False
+
+
+ class LoRALinear(LoRALayer):
+     # LoRA implemented in a dense layer
+     def __init__(
+         self,
+         # ↓ this part is for pretrained weights
+         in_features: int,
+         out_features: int,
+         # ↓ the remaining part is for LoRA
+         r: int = 0,
+         lora_alpha: int = 1,
+         lora_dropout: float = 0.0,
+         tasks=None,
+         **kwargs,
+     ):
+         """LoRA wrapper around linear class.
+
+         This class has three weight matrices:
+             1. Pretrained weights are stored as `self.linear.weight`
+             2. LoRA A matrix as `self.lora_A`
+             3. LoRA B matrix as `self.lora_B`
+         Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
+
+         Args:
+             in_features: number of input features of the pretrained weights
+             out_features: number of output features of the pretrained weights
+             r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
+                 the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
+             lora_alpha: alpha is needed for scaling updates as alpha/r
+                 "This scaling helps to reduce the need to retune hyperparameters when we vary r"
+                 https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
+             lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
+         """
+         super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+         self.linear = torch.nn.Linear(
+             in_features, out_features, **kwargs)
+
+         # Actual trainable parameters
+         if r > 0:
+             self.lora_A = nn.Parameter(
+                 self.linear.weight.new_zeros((r, in_features)))
+             self.lora_B = nn.Parameter(
+                 self.linear.weight.new_zeros((out_features, r)))
+             self.scaling = self.lora_alpha / self.r
+             self.reset_parameters()
+
+     def reset_parameters(self):
+         """Reset all the weights, even including pretrained ones."""
+         if hasattr(self, "lora_A"):
+             # initialize A the same way as the default for nn.Linear and B to zero
+             # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
+             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_B)
+
+     def merge(self):
+         """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
+         if self.r > 0 and not self.merged:
+             # Merge the weights and mark it
+             self.linear.weight.data += (self.lora_B @
+                                         self.lora_A) * self.scaling
+             self.merged = True
+
+     def forward(self, x: torch.Tensor):
+         # if the weights are merged or the rank is zero (LoRA is disabled), this is only a regular nn.Linear forward pass;
+         # otherwise, additionally run the forward pass with the LoRA weights and add its output to the output of the pretrained weights
+         pretrained = self.linear(x)
+         if self.r == 0 or self.merged:
+             return pretrained
+         lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
+                 @ self.lora_B.transpose(0, 1)) * self.scaling
+         return pretrained + lora
+
+
+ class MTLoRALinear(LoRALayer):
+     # LoRA implemented in a dense layer
+     def __init__(
+         self,
+         # ↓ this part is for pretrained weights
+         in_features: int,
+         out_features: int,
+         # ↓ the remaining part is for LoRA
+         r: Union[int, Mapping[str, int]] = 0,
+         lora_shared_scale: float = 1.0,
+         lora_task_scale: float = 1.0,
+         lora_dropout: float = 0.0,
+         tasks=None,
+         trainable_scale_shared=False,
+         trainable_scale_per_task=False,
+         shared_mode: str = 'matrix',
+         **kwargs,
+     ):
+         assert shared_mode in ['matrix', 'matrixv2',
+                                'add', 'addition', 'lora_only']
+         if shared_mode == 'add':
+             shared_mode = 'addition'
+         if shared_mode == 'lora_only':
+             tasks = None
+         has_tasks = tasks is not None
+         if not has_tasks:
+             if shared_mode not in ['matrix']:
+                 shared_mode = 'matrix'
+
+         if isinstance(r, int):
+             r = {'shared': r}
+
+         super().__init__(
+             r=r['shared'], lora_alpha=lora_shared_scale, lora_dropout=lora_dropout)
+
+         self.linear = torch.nn.Linear(
+             in_features, out_features, **kwargs)
+
+         self.tasks = tasks
+         self.shared_mode = shared_mode
+         if r['shared'] > 0:
+             if has_tasks:
+                 self.lora_tasks_A = nn.ParameterDict({
+                     task: nn.Parameter(
+                         self.linear.weight.new_zeros((r[task], in_features)))
+                     for task in tasks
+                 })
+                 self.lora_tasks_B = nn.ParameterDict({
+                     task: nn.Parameter(
+                         self.linear.weight.new_zeros((out_features, r[task])))
+                     for task in tasks
+                 })
+                 if trainable_scale_per_task:
+                     self.lora_task_scale = nn.ParameterDict({
+                         task: nn.Parameter(torch.FloatTensor(
+                             [lora_task_scale]))
+                         for task in tasks
+                     })
+                 else:
+                     self.lora_task_scale = {task: lora_task_scale[task]
+                                             for task in tasks}
+             if self.shared_mode == 'addition':
+                 assert has_tasks
+                 self.lora_norm = nn.LayerNorm(out_features)
+             elif self.shared_mode == 'matrix' or self.shared_mode == 'matrixv2':
+                 self.lora_shared_A = nn.Parameter(
+                     self.linear.weight.new_zeros((r['shared'], in_features)))
+                 self.lora_shared_B = nn.Parameter(
+                     self.linear.weight.new_zeros((out_features, r['shared'])))
+             else:
+                 raise NotImplementedError
+             if trainable_scale_shared:
+                 self.lora_shared_scale = nn.Parameter(
+                     torch.FloatTensor([lora_shared_scale]))
+             else:
+                 self.lora_shared_scale = lora_shared_scale
+             self.reset_parameters()
+
+     def reset_parameters(self):
+         """Reset all the weights, even including pretrained ones."""
+         if hasattr(self, "lora_shared_A"):
+             # initialize A the same way as the default for nn.Linear and B to zero
+             # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314
+             nn.init.kaiming_uniform_(self.lora_shared_A, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_shared_B)
+         if hasattr(self, "lora_tasks_A"):
+             for task in self.tasks:
+                 nn.init.kaiming_uniform_(
+                     self.lora_tasks_A[task], a=math.sqrt(5))
+                 nn.init.zeros_(self.lora_tasks_B[task])
+
+     def merge(self):
+         """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor, x_tasks: Dict[str, torch.Tensor] = None):
+         # TODO: handle merging
+         pretrained = self.linear(x)
+         if self.r == 0:
+             return pretrained, None
+         x = self.lora_dropout(x)
+         if self.shared_mode == 'matrix':
+             lora = (x @ self.lora_shared_A.transpose(0, 1)
+                     @ self.lora_shared_B.transpose(0, 1)) * self.lora_shared_scale
+             lora_tasks = {
+                 task: pretrained + (x_task_input @ self.lora_tasks_A[task].transpose(
+                     0, 1) @ self.lora_tasks_B[task].transpose(0, 1) * self.lora_task_scale[task])
+                 # Iterate over the items in the x_tasks dict that was PASSED IN
+                 for task, x_task_input in x_tasks.items()
+             } if x_tasks is not None else None
+         elif self.shared_mode == 'matrixv2':
+             lora = (x @ self.lora_shared_A.transpose(0, 1)
+                     @ self.lora_shared_B.transpose(0, 1)) * self.lora_shared_scale
+             lora_tasks = {
+                 task: pretrained + lora + ((x if x_tasks is None else x_tasks[task]) @ self.lora_tasks_A[task].transpose(
+                     0, 1) @ self.lora_tasks_B[task].transpose(0, 1) * self.lora_task_scale[task])
+                 for task in self.tasks
+             } if self.tasks is not None else None
+         elif self.shared_mode == 'addition':
+             lora_tasks = {
+                 task: pretrained + ((x if x_tasks is None else x_tasks[task]) @ self.lora_tasks_A[task].transpose(
+                     0, 1) @ self.lora_tasks_B[task].transpose(0, 1) * self.lora_task_scale[task])
+                 for task in self.tasks
+             } if self.tasks is not None else None
+             lora = self.lora_norm(torch.sum(torch.stack(
+                 list(lora_tasks.values()), dim=0), dim=0))
+
+         return pretrained + lora, lora_tasks
+
+
+ class MTLoRAQKV(LoRALayer):
+     def __init__(
+         self,
+         # ↓ this part is for pretrained weights
+         in_features: int,
+         out_features: int,
+         # ↓ the remaining part is for LoRA
+         r: Union[int, Mapping[str, int]] = 0,
+         lora_shared_scale: float = 1.0,
+         lora_task_scale: float = 1.0,
+         lora_dropout: float = 0.0,
+         tasks=None,
+         trainable_scale_shared=False,
+         trainable_scale_per_task=False,
+         shared_mode: str = 'matrix',
+         **kwargs,
+     ):
+         if isinstance(r, int):
+             r = {'shared': r}
+         super().__init__(r=r['shared'], lora_alpha=lora_shared_scale, lora_dropout=lora_dropout)
+         self.tasks = tasks
+         self.q = MTLoRALinear(in_features, out_features, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=lora_task_scale, lora_dropout=lora_dropout,
+                               tasks=tasks, trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode, **kwargs)
+         self.k = MTLoRALinear(in_features, out_features, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=lora_task_scale, lora_dropout=lora_dropout,
+                               tasks=tasks, trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode, **kwargs)
+         self.v = MTLoRALinear(in_features, out_features, r=r, lora_shared_scale=lora_shared_scale, lora_task_scale=lora_task_scale, lora_dropout=lora_dropout,
+                               tasks=tasks, trainable_scale_shared=trainable_scale_shared, trainable_scale_per_task=trainable_scale_per_task, shared_mode=shared_mode, **kwargs)
+
+     def reset_parameters(self):
+         self.q.reset_parameters()
+         self.k.reset_parameters()
+         self.v.reset_parameters()
+
+     def merge(self):
+         raise NotImplementedError
+
+     def forward(self, x: torch.Tensor, x_tasks: Dict[str, torch.Tensor] = None):
+         # Run each projection once and reuse both of its outputs (shared and per-task)
+         q, q_tasks = self.q(x, x_tasks)
+         k, k_tasks = self.k(x, x_tasks)
+         v, v_tasks = self.v(x, x_tasks)
+         return (torch.cat([q, k, v], dim=-1),
+                 {task: torch.cat([q_tasks[task], k_tasks[task], v_tasks[task]], dim=-1)
+                  for task in self.tasks} if self.tasks is not None else None)
+
+
+ class LoRAQKVLinear(LoRALinear):
+     # LoRA implemented in a dense layer
+     def __init__(
+         self,
+         # ↓ this part is for pretrained weights
+         in_features: int,
+         out_features: int,
+         # ↓ the remaining part is for LoRA
+         n_head: int,
+         n_query_groups: int,
+         r: int = 0,
+         lora_alpha: int = 1,
+         lora_dropout: float = 0.0,
+         enable_lora: Union[bool, Tuple[bool, bool, bool]] = False,
+         **kwargs,
+     ):
+         """LoRA wrapper around linear class that is used for calculation of q, k and v matrices.
+
+         This class has three weight matrices:
+             1. Pretrained weights are stored as `self.linear.weight`
+             2. LoRA A matrix as `self.lora_A`
+             3. LoRA B matrix as `self.lora_B`
+         Only LoRA's A and B matrices are updated, pretrained weights stay frozen.
+
+         Args:
+             in_features: number of input features of the pretrained weights
+             out_features: number of output features of the pretrained weights
+             n_head: number of attention heads
+             n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`)
+             r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of
+                 the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2)
+             lora_alpha: alpha is needed for scaling updates as alpha/r
+                 "This scaling helps to reduce the need to retune hyperparameters when we vary r"
+                 https://arxiv.org/pdf/2106.09685.pdf (section 4.1)
+             lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A)
+             enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we
+                 don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query`
+                 and `value` but keep `key` without weight updates we should pass `[True, False, True]`
+         """
+         super(LoRALinear, self).__init__(
+             r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+         self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
+         self.n_head = n_head
+         self.n_query_groups = n_query_groups
+         if isinstance(enable_lora, bool):
+             enable_lora = [enable_lora] * 3
+         assert len(enable_lora) == 3
+         self.enable_lora = enable_lora
+
+         # Actual trainable parameters
+         # To better understand initialization let's imagine that we have such parameters:
+         # ⚬ in_features: 128 (embeddings_size)
+         # ⚬ out_features: 384 (3 * embedding_size)
+         # ⚬ r: 2
+         # ⚬ enable_lora: [True, False, True]
+         if r > 0 and any(enable_lora):
+             self.lora_A = nn.Parameter(self.linear.weight.new_zeros(
+                 (r * sum(enable_lora), in_features)))  # (4, 128)
+             enable_q, enable_k, enable_v = enable_lora
+             self.kv_embd_size = self.linear.in_features // (
+                 n_head // n_query_groups)
+             # qkv_shapes will be used to split a tensor with weights correctly
+             qkv_shapes = (
+                 self.linear.in_features * enable_q,
+                 self.kv_embd_size * enable_k,
+                 self.kv_embd_size * enable_v,
+             )
+             self.qkv_shapes = [s for s in qkv_shapes if s]
+             self.lora_B = nn.Parameter(self.linear.weight.new_zeros(
+                 sum(self.qkv_shapes), r))  # (256, 2)
+             # Notes about shapes above
+             # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
+             #   128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in
+             #   F.linear function weights are automatically transposed. In addition conv1d requires channels to
+             #   be before seq length
+             # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is
+             #   128*2; 2 tells to have two channels per group for group convolution
+
+             # Scaling:
+             # This balances the pretrained model's knowledge and the new task-specific adaptation
+             # https://lightning.ai/pages/community/tutorial/lora-llm/
+             # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set
+             # alpha to a lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can
+             # tune these values to your needs. This value can be even slightly greater than 1.0!
+             # https://github.com/cloneofsimo/lora
+             self.scaling = self.lora_alpha / self.r
+
+             # Compute the indices
+             # Indices are needed to properly pad weight updates with zeros. If we want to fine-tune queries and values,
+             # but not keys, then the weights update should be:
+             #
+             # [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
+             #  [....................................],
+             #  [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
+             #       ↑              ↑            ↑
+             # ________________________________________
+             # |    query     |    key    |   value   |
+             # ----------------------------------------
+             self.lora_ind = []
+             if enable_q:
+                 self.lora_ind.extend(range(0, self.linear.in_features))
+             if enable_k:
+                 self.lora_ind.extend(
+                     range(self.linear.in_features, self.linear.in_features + self.kv_embd_size))
+             if enable_v:
+                 self.lora_ind.extend(
+                     range(self.linear.in_features + self.kv_embd_size, self.linear.out_features))
+             self.reset_parameters()
+
+     def zero_pad(self, x: torch.Tensor) -> torch.Tensor:
+         """Properly pad weight updates with zeros.
+
+         If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys,
+         then the weights update should be:
+
+         [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
+          [....................................],
+          [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
+               ↑              ↑            ↑
+         ________________________________________
+         |    query     |    key    |   value   |
+         ----------------------------------------
+
+         Args:
+             x: tensor with weights update that will be padded with zeros if necessary
+
+         Returns:
+             A tensor with weight updates and zeros for deselected q, k or v
+         """
+         # we need to do zero padding only if LoRA is disabled for one of QKV matrices
+         if all(self.enable_lora):
+             return x
+
+         # Let's imagine that:
+         # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size)
+         # ⚬ embeddings_size: 128
+         # ⚬ self.linear.out_features: 384 (3 * embeddings_size)
+         # ⚬ enable_lora: [True, False, True]
+         # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected
+         # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but
+         # only for key updates (this is where self.lora_ind comes in handy)
+         # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors
+         # for example when we want to merge/unmerge LoRA weights and pretrained weights
+         x = x.transpose(0, 1)
+         result = x.new_zeros(
+             (*x.shape[:-1], self.linear.out_features))  # (64, 64, 384)
+         result = result.view(-1, self.linear.out_features)  # (4096, 384)
+         result = result.index_copy(
+             1, torch.tensor(
+                 self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes))
+         )  # (4096, 256)
+         # (64, 64, 384)
+         return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1)
+
+     def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+         """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries.
+
+         If the number of heads is equal to the number of query groups - grouped queries are disabled
+         (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
+         query, key and value parts, which means we can utilize the `groups` argument from `conv1d`: with this argument the
+         input and weight matrices will be split in equally sized parts and applied separately (like having multiple
+         conv layers side by side).
+
+         Otherwise the QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually,
+         apply each part of the weight matrix to the corresponding input's part and concatenate the result.
+
+         Args:
+             input: input matrix of shape (B, C, T)
+             weight: weight matrix of shape (C_output, rank, 1).
+                 "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class).
+
+         Returns:
+             A tensor with a shape (B, C_output, T)
+
+         """
+         if self.n_head == self.n_query_groups:
+             # (B, C_output, T)
+             return F.conv1d(input, weight, groups=sum(self.enable_lora))
+
+         # Notation:
+         # ⚬ N: number of enabled LoRA layers (self.enable_lora)
+         # ⚬ C_output': embeddings size for each LoRA layer (not equal in size)
+         # ⚬ r: rank of all LoRA layers (equal in size)
+
+         input_splitted = input.chunk(
+             sum(self.enable_lora), dim=1)  # N * (B, C // N, T)
+         weight_splitted = weight.split(
+             self.qkv_shapes)  # N * (C_output', r, 1)
+         return torch.cat(
+             # (B, C_output', T)
+             [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1
+         )  # (B, C_output, T)
+
+     def merge(self):
+         """Merges the LoRA weights into the full-rank weights (W = W + delta_W)."""
+
+         # Let's assume that:
+         # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size)
+         # ⚬ self.lora_A.data: (4, 128)
+         # ⚬ self.lora_B.data: (256, 2)
+         if self.r > 0 and any(self.enable_lora) and not self.merged:
+             delta_w = self.conv1d(
+                 self.lora_A.data.unsqueeze(0),  # (4, 128) -> (1, 4, 128)
+                 self.lora_B.data.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
+             ).squeeze(
+                 0
+             )  # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128)
+             # W = W + delta_W (merge)
+             # (256, 128) after zero_pad (384, 128)
+             self.linear.weight.data += self.zero_pad(delta_w * self.scaling)
+             self.merged = True
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Do the forward pass.
+
+         If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication.
+         If not, then multiply pretrained weights with input, apply LoRA on input and do summation.
+
+         Args:
+             x: input tensor of shape (batch_size, context_length, embedding_size)
+
+         Returns:
+             Output tensor of shape (batch_size, context_length, 3 * embedding_size)
+         """
+
+         # Let's assume that:
+         # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size)
+         # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size)
+         # ⚬ self.lora_A.data: (4, 128)
+         # ⚬ self.lora_B.data: (256, 2)
+
+         # if the weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False), this is only a regular nn.Linear forward pass;
+         # otherwise, additionally run the forward pass with the LoRA weights and add its output to the output of the pretrained weights
+         pretrained = self.linear(x)
+         if self.r == 0 or not any(self.enable_lora) or self.merged:
+             return pretrained
+         # (64, 64, 128) @ (4, 128) -> (64, 64, 4)
+         after_A = F.linear(self.lora_dropout(x), self.lora_A)
+         # For F.conv1d:
+         # ⚬ input: input tensor of shape (mini-batch, in_channels, iW)
+         # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW)
+         after_B = self.conv1d(
+             after_A.transpose(-2, -1),  # (64, 64, 4) -> (64, 4, 64)
+             self.lora_B.unsqueeze(-1),  # (256, 2) -> (256, 2, 1)
+         ).transpose(
+             -2, -1
+         )  # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256)
+         # (64, 64, 256) after zero_pad (64, 64, 384)
+         lora = self.zero_pad(after_B) * self.scaling
+         return pretrained + lora
+
+
+ def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none", freeze_patch_embed: bool = False, freeze_norm: bool = False, free_relative_bias: bool = False, freeze_downsample_reduction=False) -> None:
+     """Freeze all modules except LoRA's and, depending on the 'bias' value, unfreeze bias weights.
+
+     Args:
+         model: model with LoRA layers
+         bias:
+             ``"none"``: all bias weights will be frozen,
+             ``"lora_only"``: only bias weights of LoRA layers will be unfrozen,
+             ``"all"``: all bias weights will be unfrozen.
+         freeze_patch_embed: if True, patch embedding weights are frozen as well (by default they stay trainable)
+         freeze_norm: if True, normalization layers are frozen as well (by default they stay trainable)
+         free_relative_bias: if True, relative position bias tables are frozen as well (by default they stay trainable)
+         freeze_downsample_reduction: if True, downsample reduction weights are frozen as well (by default they stay trainable)
+
+     Raises:
+         NotImplementedError: if `bias` not in ["none", "lora_only", "all"]
+     """
+     def lora_filter(key):
+         return "lora_" in key
+
+     def patch_embed_filter(key):
+         return not freeze_patch_embed and "patch_embed" in key
+
+     def norm_filter(key):
+         return not freeze_norm and "norm" in key
+
+     def downsample_reduction_filter(key):
+         return not freeze_downsample_reduction and "downsample.reduction" in key
+
+     def relative_position_bias_filter(key):
+         return not free_relative_bias and "relative_position_bias_table" in key
+
+     def all_filters(key):
+         return lora_filter(key) or patch_embed_filter(key) or norm_filter(key) or downsample_reduction_filter(key) or relative_position_bias_filter(key)
+
+     print(f"LoRA bias mode: {bias}")
+     print(f"LoRA Freeze patch_embed: {freeze_patch_embed}")
+     print(f"LoRA Freeze norm: {freeze_norm}")
+     print(f"LoRA Freeze downsample_reduction: {freeze_downsample_reduction}")
+     print(f"LoRA Freeze relative_position_bias: {free_relative_bias}")
+     # freeze all layers except LoRA's
+     for n, p in model.named_parameters():
+         if not all_filters(n):
+             p.requires_grad = False
+
+     # depending on the `bias` value unfreeze bias weights
+     if bias == "none":
+         return
+     if bias == "all":
+         for n, p in model.named_parameters():
+             if "bias" in n:
+                 p.requires_grad = True
+     elif bias == "lora_only":
+         for m in model.modules():
+             if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None:
+                 m.bias.requires_grad = True
+     else:
+         raise NotImplementedError
+
+
+ def lora_filter(key: str, value: Any) -> bool:
+     return "lora_" in key
+
+
+ def merge_lora_weights(model) -> None:
+     """Merge LoRA weights into the full-rank weights to speed up inference."""
+     for module in model.modules():
+         if isinstance(module, LoRALinear):
+             module.merge()
+
+
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str, split_qkv: bool = False) -> Dict:
+     unmatched_keys = []
+     for checkpoint_name, attribute_name in mapping.items():
+         full_checkpoint_name = prefix + checkpoint_name
+         if full_checkpoint_name in state_dict:
+             full_attribute_name = prefix + attribute_name
+             weights = state_dict.pop(full_checkpoint_name)
+             last_four = ".".join(full_attribute_name.split(".")[-4:])
+             if split_qkv and last_four in ["attn.qkv.linear.weight", "attn.qkv.linear.bias"]:
+                 w_q, w_k, w_v = torch.chunk(weights, chunks=3)
+                 weight_bias = last_four.split(".")[-1]
+                 full_attribute_name_without_suffix = ".".join(
+                     full_attribute_name.split(".")[:-2])
+                 state_dict[f"{full_attribute_name_without_suffix}.q.linear.{weight_bias}"] = w_q
+                 state_dict[f"{full_attribute_name_without_suffix}.k.linear.{weight_bias}"] = w_k
+                 state_dict[f"{full_attribute_name_without_suffix}.v.linear.{weight_bias}"] = w_v
+             else:
+                 state_dict[full_attribute_name] = weights
+         else:
+             unmatched_keys.append(checkpoint_name)
+     if len(unmatched_keys) > 0:
+         print(
+             f"WARNING: The following keys from the checkpoint were not mapped: {unmatched_keys}")
+     return state_dict
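
For context, a minimal usage sketch of the module added above. It assumes the file is importable as `src.dlora`; the layer sizes, ranks, and task names ("seg", "depth") are illustrative placeholders, not values taken from this commit.

import torch

from src.dlora import (LoRALinear, MTLoRALinear,
                       mark_only_lora_as_trainable, merge_lora_weights)

# Single-task LoRA: wrap a 128 -> 128 projection with a rank-4 update.
layer = LoRALinear(128, 128, r=4, lora_alpha=8, lora_dropout=0.1)
x = torch.randn(2, 16, 128)
y = layer(x)  # frozen linear path plus the scaled (B @ A) update

# Multi-task LoRA: one shared low-rank branch plus one branch per task.
tasks = ["seg", "depth"]  # hypothetical task names
mt = MTLoRALinear(
    128, 128,
    r={"shared": 4, "seg": 2, "depth": 2},
    lora_shared_scale=1.0,
    lora_task_scale={"seg": 1.0, "depth": 1.0},
    tasks=tasks,
    shared_mode="matrix",
)
shared_out, task_outs = mt(x)                         # task_outs is None without per-task inputs
shared_out, task_outs = mt(x, {t: x for t in tasks})  # one output per task

# Train only the LoRA parameters, then fold single-task updates into the
# frozen weights for inference.
mark_only_lora_as_trainable(mt, bias="none")
merge_lora_weights(layer)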