# NOTE(review): removed extraction artifacts ("Spaces:", "Runtime error") that
# were not part of the original source file.
import ast

import numpy as np
import torch
| """ | |
| TRAIN FUNCTION DEFINITION: | |
| train(model: StableDiffusionPipeline, | |
| projection_matrices: list[size=L](nn.Module), | |
| og_matrices: list[size=L](nn.Module), | |
| contexts: list[size=N](torch.tensor[size=MAX_LEN,...]), | |
| valuess: list[size=N](list[size=L](torch.tensor[size=MAX_LEN,...])), | |
| old_texts: list[size=N](str), | |
| new_texts: list[size=N](str), | |
| **kwargs) | |
| where L is the number of matrices to edit, and N is the number of sentences to train on (batch size). | |
| PARAMS: | |
| model: the model to use. | |
| projection_matrices: list of projection matrices to edit from the model. | |
| og_matrices: list of original values for the projection matrices. detached from the model. | |
| contexts: list of context vectors (inputs to the matrices) to edit. | |
| valuess: list of results from all matrices for each context vector. | |
| old_texts: list of sentences to be edited. | |
| new_texts: list of target sentences to be aimed at. | |
| **kwargs: additional command line arguments. | |
| TRAIN_FUNC_DICT defined at the bottom of the file. | |
| """ | |
def baseline_train(model, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts):
    """No-op baseline: leave every projection matrix untouched.

    Follows the shared TRAIN FUNCTION DEFINITION signature so it can be
    dispatched interchangeably with the real editing functions.

    Returns:
        None. No weights are modified.
    """
    return None
def train_closed_form(ldm_stable, projection_matrices, og_matrices, contexts, valuess,
                      old_texts, new_texts, layers_to_edit=None, lamb=0.1):
    """Edit projection matrices with a closed-form ridge-regression update.

    For each selected layer with weight W, computes

        W_new = (lamb * W + sum_i v_i k_i^T) @ (lamb * I + sum_i k_i k_i^T)^-1

    where k_i are the context (key) vectors and v_i the corresponding target
    values for that layer, then writes W_new back into the layer in place.

    Args:
        ldm_stable: the model/pipeline (unused here; kept for the shared
            train-function signature).
        projection_matrices: list of nn.Module layers whose ``.weight`` is
            edited in place.
        og_matrices: original detached copies of the layers (unused here;
            kept for the shared signature).
        contexts: per-sentence context tensors; each is reshaped as a batch
            of column vectors, so shape (num_tokens, d_in) is assumed —
            TODO confirm against callers.
        valuess: per-sentence list of per-layer target tensors, each of
            shape (num_tokens, d_out).
        old_texts: sentences to be edited (unused here; shared signature).
        new_texts: target sentences (unused here; shared signature).
        layers_to_edit: optional collection of layer indices to restrict the
            edit to; a string is parsed with ast.literal_eval. None edits
            all layers.
        lamb: regularization strength; a string is parsed with
            ast.literal_eval.

    Returns:
        None. Layer weights are updated in place.
    """
    # Command-line values may arrive as strings; parse them safely
    # (literal_eval, never eval).
    if isinstance(layers_to_edit, str):
        layers_to_edit = ast.literal_eval(layers_to_edit)
    if isinstance(lamb, str):
        lamb = ast.literal_eval(lamb)

    for layer_num, layer in enumerate(projection_matrices):
        if layers_to_edit is not None and layer_num not in layers_to_edit:
            continue
        with torch.no_grad():
            weight = layer.weight
            # mat1 = lamb * W + sum{v k^T}
            mat1 = lamb * weight
            # mat2 = lamb * I + sum{k k^T}
            mat2 = lamb * torch.eye(weight.shape[1], device=weight.device)
            # Accumulate the batched outer products over all sentences.
            for context, values in zip(contexts, valuess):
                context_vector = context.reshape(context.shape[0], context.shape[1], 1)
                context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
                value = values[layer_num]
                value_vector = value.reshape(value.shape[0], value.shape[1], 1)
                mat1 = mat1 + (value_vector @ context_vector_T).sum(dim=0)
                mat2 = mat2 + (context_vector @ context_vector_T).sum(dim=0)
            # Closed-form solution; torch.linalg.inv is the maintained
            # replacement for the deprecated torch.inverse.
            layer.weight = torch.nn.Parameter(mat1 @ torch.linalg.inv(mat2))
# Dispatch table mapping a training-method name (e.g. from a command-line
# flag) to its train function; every entry follows the shared TRAIN FUNCTION
# DEFINITION signature documented at the top of the file.
TRAIN_FUNC_DICT = {
    "baseline": baseline_train,
    "train_closed_form": train_closed_form,
}