Skip to content

Commit

Permalink
Added dect.
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Nov 18, 2024
1 parent 624056c commit 1193be6
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "TopoModelX"]
path = dependencies/TopoModelX
url = [email protected]:pyt-team/TopoModelX.git
[submodule "dependencies/dect"]
path = dependencies/dect
url = [email protected]:aidos-lab/dect.git
152 changes: 152 additions & 0 deletions code/models/DECT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import torch.nn as nn
from torch_geometric.nn import pool
from torch_geometric.data import Data

from pydantic import BaseModel

from .model_types import ModelType
from dataclasses import dataclass
import torch


class DECTConfig(BaseModel):
type: ModelType = ModelType.DECT
num_hidden_neurons: int = 64
num_hidden_layers: int = 3
num_node_features: int = 1
out_channels: int = 5


@dataclass(frozen=True)
class EctConfig:
"""
Config for initializing an ect layer.
"""

num_thetas: int = 32
bump_steps: int = 32
r: float = 1.1
ect_type: str = "points"
normalized: bool = True


def generate_uniform_directions(num_thetas: int = 64, d: int = 3):
"""
Generate randomly sampled directions from a sphere in d dimensions.
First a standard gaussian centered at 0 with standard deviation 1 is sampled
and then projected onto the unit sphere. This yields a uniformly sampled set
of points on the unit spere. Please note that the generated shapes with have
shape [d, num_thetas].
Parameters
----------
num_thetas: int
The number of directions to generate.
d: int
The dimension of the unit sphere. Default is 3 (hence R^3)
"""
v = torch.randn(size=(d, num_thetas))
v /= v.pow(2).sum(axis=0).sqrt().unsqueeze(1).T
return v


def compute_ecc(nh, index, lin, out, scale):
"""
Computes the ECC of a set of points given the node heights.
"""
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
return torch.index_add(out, 1, index, ecc).movedim(0, 1)


def compute_ect_points(data, index, v, lin, out, scale):
"""Compute the ECT of a set of points."""
nh = data.x @ v
return compute_ecc(nh, index, lin, out, scale)


class EctLayer(nn.Module):
"""Docstring for EctLayer."""

def __init__(self, config: EctConfig, v=None):
super().__init__()
self.config = config
self.lin = nn.Parameter(
(
torch.linspace(-config.r, config.r, config.bump_steps).view(
-1, 1, 1
)
),
requires_grad=False,
)
if v is None:
raise AttributeError("Please provide the directions")
self.v = nn.Parameter(v, requires_grad=False)

if config.ect_type == "points":
self.compute_ect = compute_ect_points

def forward(self, data: Data, index, scale=500):
"""Forward method"""
out = torch.zeros(
size=(
self.config.bump_steps,
index.max().item() + 1,
self.config.num_thetas,
),
device=data.x.device,
)
ect = self.compute_ect(data, index, self.v, self.lin, out, scale)
if self.config.normalized:
return ect / torch.amax(ect, dim=(1, 2)).unsqueeze(1).unsqueeze(1)
return ect


class DECTMLP(nn.Module):
def __init__(
self,
config: DECTConfig,
):
super().__init__()

self.ectconfig = EctConfig()
v = generate_uniform_directions(
self.ectconfig.num_thetas, d=config.num_node_features
)
self.layer = EctLayer(self.ectconfig, v=v)

self.input_layer = nn.Linear(
self.ectconfig.num_thetas * self.ectconfig.bump_steps,
config.num_hidden_neurons,
)
self.hidden_layers = nn.ModuleList(
[
nn.Linear(config.num_hidden_neurons, config.num_hidden_neurons)
for _ in range(config.num_hidden_layers)
]
)
self.output_layer = nn.Linear(
config.num_hidden_neurons, config.out_channels
)

def forward(self, batch):
ect = self.layer(batch, batch.batch)
x = self.input_layer(ect.flatten(start_dim=1))
x = nn.functional.relu(x)
for hidden_layer in self.hidden_layers:
x = hidden_layer(x)
x = nn.functional.relu(x)
x = self.output_layer(x)
return x


if __name__ == "__main__":
from torch_geometric.data import Batch, Data

batch = Batch.from_data_list(
[Data(x=torch.rand(size=(10, 8))), Data(x=torch.rand(size=(12, 8)))]
)

config = DECTConfig(num_node_features=8)
model = DECTMLP(config=config)
res = model(batch)
1 change: 1 addition & 0 deletions code/models/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


class ModelType(Enum):
DECT = "dect"
SAN = "san"
SCCN = "sccn"
SCCNN = "sccnn"
Expand Down
5 changes: 5 additions & 0 deletions code/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from models.GCN import GCN, GCNConfig
from models.GAT import GAT, GATConfig
from models.MLP import MLP, MLPConfig
from models.DECT import DECTMLP, DECTConfig
from models.TransfConv import TransfConv, TransfConvConfig
from models.TAG import TAG, TAGConfig
from .model_types import ModelType
Expand All @@ -22,6 +23,7 @@
from models.simplicial_complexes.scn import SCN, SCNConfig

model_lookup: Dict[ModelType, nn.Module] = {
ModelType.DECT: DECTMLP,
ModelType.GAT: GAT,
ModelType.GCN: GCN,
ModelType.MLP: MLP,
Expand All @@ -34,6 +36,7 @@
}

ModelConfig = Union[
Annotated[DECTConfig, Tag(ModelType.DECT)],
Annotated[MLPConfig, Tag(ModelType.MLP)],
Annotated[GATConfig, Tag(ModelType.GAT)],
Annotated[GCNConfig, Tag(ModelType.GCN)],
Expand All @@ -46,6 +49,7 @@
]

model_cfg_lookup: Dict[ModelType, ModelConfig] = {
ModelType.DECT: DECTConfig,
ModelType.GAT: GATConfig,
ModelType.GCN: GCNConfig,
ModelType.MLP: MLPConfig,
Expand All @@ -58,6 +62,7 @@
}

dataloader_lookup: Dict[ModelType, Callable] = {
ModelType.DECT: DataLoader,
ModelType.GAT: DataLoader,
ModelType.GCN: DataLoader,
ModelType.MLP: DataLoader,
Expand Down
1 change: 1 addition & 0 deletions dependencies/dect
Submodule dect added at da55c1

0 comments on commit 1193be6

Please sign in to comment.