benjamin-paine committed on
Commit
f2679a8
·
verified ·
1 Parent(s): 1eba3ec

Create flux2_tiny_autoencoder.py

Browse files
Files changed (1) hide show
  1. flux2_tiny_autoencoder.py +105 -0
flux2_tiny_autoencoder.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Apache License
2
+ #
3
+ # Copyright 2025 fal - features and labels, inc.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from diffusers.models import AutoencoderTiny
21
+ from diffusers.models.autoencoders.vae import EncoderOutput, DecoderOutput
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+
24
+ from flashpack.integrations.diffusers import FlashPackDiffusersModelMixin
25
+
26
class Flux2TinyAutoEncoder(FlashPackDiffusersModelMixin, ConfigMixin):
    """Tiny autoencoder producing Flux-2-shaped latents.

    Wraps a :class:`diffusers.AutoencoderTiny` (TAESD-style) that operates at
    ``latent_channels // 4`` channels, adding:

    - an ``extra_encoder`` strided conv that downsamples 2x spatially and
      expands to the full ``latent_channels`` depth, with a matching
      ``extra_decoder`` transposed conv that reverses it; and
    - residual refinement blocks applied to the latents on both paths.

    NOTE(review): the default ``scaling_factor=0.13025`` matches SDXL's VAE
    convention — confirm against the trained checkpoint.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        latent_channels: int = 128,
        # Tuples (not lists) so defaults are immutable — a shared mutable
        # default would be aliased across every instance of the class.
        encoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "silu",
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: float = 3.0,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 0.13025,
    ) -> None:
        """Build the inner tiny VAE plus the extra compress/decompress convs.

        Args:
            in_channels: Image channels fed to the encoder.
            out_channels: Image channels produced by the decoder.
            latent_channels: Final latent depth; the inner tiny VAE runs at a
                quarter of this, so it must be divisible by 4 (and, for the
                GroupNorm(8, ...) layers below, ``latent_channels // 4`` must
                be divisible by 8 — true for the default of 128).
            encoder_block_out_channels / decoder_block_out_channels: Per-stage
                widths forwarded to ``AutoencoderTiny``.
            act_fn: Activation name forwarded to ``AutoencoderTiny``.
            upsampling_scaling_factor: Decoder upsample factor per stage.
            num_encoder_blocks / num_decoder_blocks: Blocks per stage.
            latent_magnitude / latent_shift: Latent (un)scaling constants
                forwarded to ``AutoencoderTiny``.
            force_upcast: Whether to force fp32 in the inner VAE.
            scaling_factor: Latent scaling factor stored in the config.
        """
        super().__init__()
        # Inner TAESD-style VAE operates at a quarter of the target latent
        # depth; extra_encoder/extra_decoder map between the two.
        self.tiny_vae = AutoencoderTiny(
            in_channels=in_channels,
            out_channels=out_channels,
            encoder_block_out_channels=encoder_block_out_channels,
            decoder_block_out_channels=decoder_block_out_channels,
            act_fn=act_fn,
            latent_channels=latent_channels // 4,
            upsampling_scaling_factor=upsampling_scaling_factor,
            num_encoder_blocks=num_encoder_blocks,
            num_decoder_blocks=num_decoder_blocks,
            latent_magnitude=latent_magnitude,
            latent_shift=latent_shift,
            force_upcast=force_upcast,
            scaling_factor=scaling_factor,
        )
        # 2x spatial downsample + 4x channel expansion (kernel 4, stride 2).
        self.extra_encoder = nn.Conv2d(
            latent_channels // 4, latent_channels,
            kernel_size=4, stride=2, padding=1,
        )
        # Exact inverse geometry of extra_encoder.
        self.extra_decoder = nn.ConvTranspose2d(
            latent_channels, latent_channels // 4,
            kernel_size=4, stride=2, padding=1,
        )
        # Residual refinement applied after compression (full latent depth).
        self.residual_encoder = nn.Sequential(
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, latent_channels),
            nn.SiLU(),
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1),
        )
        # Residual refinement applied after decompression (quarter depth).
        self.residual_decoder = nn.Sequential(
            nn.Conv2d(latent_channels // 4, latent_channels // 4, kernel_size=3, padding=1),
            nn.GroupNorm(8, latent_channels // 4),
            nn.SiLU(),
            nn.Conv2d(latent_channels // 4, latent_channels // 4, kernel_size=3, padding=1),
        )

    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> "EncoderOutput | tuple[torch.Tensor]":
        """Encode images ``x`` into full-depth latents.

        Returns an ``EncoderOutput`` when ``return_dict`` is True, otherwise a
        one-element tuple per the diffusers convention. (BUGFIX: previously a
        bare tensor was returned for ``return_dict=False``, so ``forward``'s
        ``[0]`` indexed into the batch dimension and silently dropped it.)
        """
        encoded = self.tiny_vae.encode(x, return_dict=False)[0]
        compressed = self.extra_encoder(encoded)
        # Residual connection around the refinement block.
        enhanced = self.residual_encoder(compressed) + compressed

        if return_dict:
            return EncoderOutput(latent=enhanced)
        return (enhanced,)

    def decode(
        self, z: torch.Tensor, return_dict: bool = True
    ) -> "DecoderOutput | tuple[torch.Tensor]":
        """Decode full-depth latents ``z`` back to images.

        Returns a ``DecoderOutput`` when ``return_dict`` is True, otherwise a
        one-element tuple per the diffusers convention (see ``encode``).
        """
        decompressed = self.extra_decoder(z)
        # Residual connection around the refinement block.
        enhanced = self.residual_decoder(decompressed) + decompressed
        decoded = self.tiny_vae.decode(enhanced, return_dict=False)[0]

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)

    def forward(
        self, sample: torch.Tensor, return_dict: bool = True
    ) -> "DecoderOutput | tuple[torch.Tensor]":
        """Round-trip ``sample`` through encode then decode (reconstruction)."""
        encoded = self.encode(sample, return_dict=False)[0]
        decoded = self.decode(encoded, return_dict=False)[0]

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)