|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from functools import partial |
|
|
from typing import Callable, Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
@dataclass()
class InitArgs:
    """Options controlling the weight initializer built by ``get_init_fn``."""

    # True -> truncated normal; False -> uniform with matching standard deviation.
    use_gaussian: bool = True
    # Optional multiplier applied to the base 1/sqrt(input_dim) standard deviation.
    coeff_std: Optional[float] = None
    # True -> the returned initializer is the identity (tensor left untouched).
    no_init: bool = False
|
|
|
|
|
|
|
|
def get_init_fn(
    args: "InitArgs", input_dim: int, init_depth: Optional[int]
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build an in-place tensor initializer scaled by fan-in and (optionally) depth.

    The standard deviation is ``1 / sqrt(input_dim)``. It is multiplied by
    ``args.coeff_std`` when that is set, and divided by ``sqrt(2 * init_depth)``
    when ``init_depth`` is given.

    Args:
        args: Initialization options.
            - ``args.no_init``: return an identity function.
            - ``args.use_gaussian``: select a normal distribution truncated at
              +/- 3 std. Otherwise, use a uniform distribution whose standard
              deviation equals std.
            - ``args.coeff_std``: optional non-negative scale on the std.
        input_dim: Dimension used for the ``1 / sqrt(input_dim)`` scaling.
            Must be positive.
        init_depth: Optional depth used for the extra
            ``1 / sqrt(2 * init_depth)`` scaling. Must be positive if given.

    Returns:
        A callable that takes a tensor, fills it in place (except in the
        ``no_init`` case, which leaves it untouched), and returns it.

    Raises:
        ValueError: If ``input_dim`` is not positive, if ``init_depth`` is
            given but not positive, or if ``args.coeff_std`` is negative.
    """
    if args.no_init:
        # Identity: the tensor is returned as-is.
        return lambda x: x

    if input_dim <= 0:
        raise ValueError(f"input_dim must be positive, got {input_dim}")
    if init_depth is not None and init_depth <= 0:
        raise ValueError(f"init_depth must be positive when given, got {init_depth}")
    if args.coeff_std is not None and args.coeff_std < 0:
        # A negative std would make trunc_normal_'s clamp bounds inverted
        # (a > b), producing a constant tensor instead of random values.
        raise ValueError(f"coeff_std must be non-negative, got {args.coeff_std}")

    std = 1 / math.sqrt(input_dim)
    if args.coeff_std is not None:
        std *= args.coeff_std
    if init_depth is not None:
        std /= math.sqrt(2 * init_depth)

    if std == 0.0:
        # trunc_normal_ divides by std and would raise ZeroDivisionError;
        # both distributions degenerate to all-zeros in this limit.
        return torch.nn.init.zeros_

    if args.use_gaussian:
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )

    # Uniform on [-sqrt(3)*std, sqrt(3)*std] has variance bound**2 / 3 == std**2.
    bound = math.sqrt(3) * std
    return partial(torch.nn.init.uniform_, a=-bound, b=bound)
|
|
|