fallback to lazy import
dcolinmorgan committed Dec 5, 2023
1 parent 58d9810 commit d02d480
Showing 2 changed files with 24 additions and 10 deletions.
28 changes: 21 additions & 7 deletions graphistry/embed_utils.py
@@ -1,4 +1,4 @@
-import logging, tqdm
+import logging
 import numpy as np
 import pandas as pd
 from typing import Optional, Union, Callable, List, TYPE_CHECKING, Any, Tuple
@@ -7,6 +7,20 @@
 from .dep_manager import deps
 
 
+def lazy_embed_import_dep():
+    try:
+        import torch
+        import torch.nn as nn
+        import dgl
+        from dgl.dataloading import GraphDataLoader
+        import torch.nn.functional as F
+        from .networks import HeteroEmbed
+        from tqdm import trange
+        return True, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange
+
+    except:
+        return False, None, None, None, None, None, None, None
+
 if TYPE_CHECKING:
     torch = deps.torch
     TT = torch.Tensor
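The new helper bundles every import the module previously did eagerly (torch, dgl, the HeteroEmbed network, tqdm's trange) behind a single call, with a leading boolean flagging whether the stack is available. A minimal caller sketch, assuming torch is installed; the has_deps name and the error message are illustrative, not from this commit:

    # Unpack the flag-first tuple and bail out early when the deep-learning
    # dependencies are missing, instead of failing at module import time.
    has_deps, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange = lazy_embed_import_dep()
    if not has_deps:
        raise ImportError("embedding support requires torch, dgl, and tqdm")
    device = "cuda" if torch.cuda.is_available() else "cpu"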
@@ -168,18 +182,18 @@ def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, device)
         return model, g_dataloader
 
     def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_size:int, num_steps:int, device) -> Plottable:
-        torch = deps.torch
-        from torch import nn
+        # torch = deps.torch
+        # from torch import nn
+        # from tqdm import trange
+        _, torch, nn, _, _, _, _, trange = lazy_embed_import_dep()
         log('Training embedding')
         model, g_dataloader = res._init_model(res, batch_size, sample_size, num_steps, device)
         if hasattr(res, "_embed_model") and not res._build_new_embedding_model:
             model = res._embed_model
             log("--Reusing previous model")
 
         optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-        # from tqdm import tqdm
-        pbar = tqdm.tqdm(range(epochs), desc=None)  # type: ignore
+        pbar = trange(epochs, desc=None)
         model.to(device)
 
         score = 0
@@ -200,7 +214,7 @@ def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_size:int, num_steps:int, device) -> Plottable:
             optimizer.step()
             pbar.set_description(
                 f"epoch: {epoch+1}, loss: {loss.item():.4f}, score: {100*score:.4f}%"
-            )  # type:ignore
+            )
 
         model.eval()
         res._kg_embeddings = model(res._kg_dgl.to(device)).detach()
@@ -209,7 +223,7 @@ def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_size:int, num_steps:int, device) -> Plottable:
             score = res._eval(threshold=0.5)
             pbar.set_description(
                 f"epoch: {epoch+1}, loss: {loss.item():.4f}, score: {100*score:.2f}%"
-            )  # type:ignore
+            )
 
         return res
 
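The progress-bar swap is behavior-preserving: tqdm's trange(n) is shorthand for tqdm(range(n)), so the returned bar still supports set_description inside the training loop. A standalone sketch of the pattern, assuming tqdm is installed; the loss value is hardcoded for illustration:

    from tqdm import trange

    pbar = trange(3, desc=None)  # equivalent to tqdm(range(3))
    for epoch in pbar:
        # Same call style as _train_embedding above, with a dummy loss.
        pbar.set_description(f"epoch: {epoch+1}, loss: {0.1234:.4f}")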
6 changes: 3 additions & 3 deletions graphistry/feature_utils.py
@@ -282,9 +282,9 @@ def remove_internal_namespace_if_present(df: pd.DataFrame):
     ]
     if (len(df.columns) <= 2):
         df = df.rename(columns={c: c + '_1' for c in df.columns if c in reserved_namespace})
-        # if (isinstance(df.columns.to_list()[0],int)):
-        #     int_namespace = pd.to_numeric(df.columns, errors = 'ignore').dropna().to_list()  # type: ignore
-        #     df = df.rename(columns={c: str(c) + '_1' for c in df.columns if c in int_namespace})
+        if (isinstance(df.columns.to_list()[0],int)):
+            int_namespace = pd.to_numeric(df.columns, errors = 'ignore').dropna().to_list()  # type: ignore
+            df = df.rename(columns={c: str(c) + '_1' for c in df.columns if c in int_namespace})
     else:
         df = df.drop(columns=reserved_namespace, errors="ignore")  # type: ignore
     return df
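The re-enabled branch handles DataFrames whose column labels are integers: pd.to_numeric keeps the numeric labels, and each is renamed to a string with a '_1' suffix so it cannot collide with graphistry's reserved internal names. A small sketch of the effect, assuming only pandas:

    import pandas as pd

    # Columns 0 and 1 are integer labels; after the rename they become strings.
    df = pd.DataFrame([[1, 2], [3, 4]], columns=[0, 1])
    int_namespace = pd.to_numeric(df.columns, errors='ignore').dropna().to_list()
    df = df.rename(columns={c: str(c) + '_1' for c in df.columns if c in int_namespace})
    print(df.columns.to_list())  # ['0_1', '1_1']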
