| """Task class definition """ | |
| from dataclasses import dataclass | |
| from typing import List, Type | |
| import torch.nn as nn | |
| class Task: | |
| """Encapsulates all configuration for a single task.""" | |
| name: str | |
| class_labels: List[str] | |
| criterion: Type[nn.Module] | |
| weight: float = 1.0 | |
| use_weighted_loss: bool = False | |
| def num_classes(self) -> int: | |
| return len(self.class_labels) | |