Spaces:
Sleeping
Sleeping
Commit
·
85f8c1c
1
Parent(s):
e75805e
Update model.py
Browse files
model.py
CHANGED
|
@@ -8,26 +8,21 @@ from torchvision.models import resnet18, resnet50
|
|
| 8 |
from torchvision.models import ResNet18_Weights, ResNet50_Weights
|
| 9 |
|
| 10 |
class DistMult(nn.Module):
|
| 11 |
-
def __init__(self,
|
| 12 |
super(DistMult, self).__init__()
|
| 13 |
-
self.args = args
|
| 14 |
self.num_ent_uid = num_ent_uid
|
| 15 |
|
| 16 |
self.num_relations = 4
|
| 17 |
|
| 18 |
-
self.ent_embedding = torch.nn.Embedding(self.num_ent_uid,
|
| 19 |
-
self.rel_embedding = torch.nn.Embedding(self.num_relations,
|
| 20 |
|
| 21 |
-
self.location_embedding = MLP(
|
| 22 |
|
| 23 |
-
self.time_embedding = MLP(
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
self.image_embedding.fc = nn.Linear(2048, args.embedding_dim)
|
| 28 |
-
else:
|
| 29 |
-
self.image_embedding = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
|
| 30 |
-
self.image_embedding.fc = nn.Linear(512, args.embedding_dim)
|
| 31 |
|
| 32 |
self.target_list = target_list
|
| 33 |
|
|
@@ -36,7 +31,6 @@ class DistMult(nn.Module):
|
|
| 36 |
if all_timestamps is not None:
|
| 37 |
self.all_timestamps = all_timestamps.to(device)
|
| 38 |
|
| 39 |
-
self.args = args
|
| 40 |
self.device = device
|
| 41 |
|
| 42 |
self.init()
|
|
|
|
| 8 |
from torchvision.models import ResNet18_Weights, ResNet50_Weights
|
| 9 |
|
| 10 |
class DistMult(nn.Module):
|
| 11 |
+
def __init__(self, num_ent_uid, target_list, device, all_locs=None, num_habitat=None, all_timestamps=None):
|
| 12 |
super(DistMult, self).__init__()
|
|
|
|
| 13 |
self.num_ent_uid = num_ent_uid
|
| 14 |
|
| 15 |
self.num_relations = 4
|
| 16 |
|
| 17 |
+
self.ent_embedding = torch.nn.Embedding(self.num_ent_uid, 512, sparse=False)
|
| 18 |
+
self.rel_embedding = torch.nn.Embedding(self.num_relations, 512, sparse=False)
|
| 19 |
|
| 20 |
+
self.location_embedding = MLP(2, 512, 3)
|
| 21 |
|
| 22 |
+
self.time_embedding = MLP(1, 512, 3)
|
| 23 |
|
| 24 |
+
self.image_embedding = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
|
| 25 |
+
self.image_embedding.fc = nn.Linear(2048, 512)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
self.target_list = target_list
|
| 28 |
|
|
|
|
| 31 |
if all_timestamps is not None:
|
| 32 |
self.all_timestamps = all_timestamps.to(device)
|
| 33 |
|
|
|
|
| 34 |
self.device = device
|
| 35 |
|
| 36 |
self.init()
|