diff --git a/code/get_adj.py b/code/get_adj.py index e4e9c3e..972eb08 100644 --- a/code/get_adj.py +++ b/code/get_adj.py @@ -1,3 +1,4 @@ +# from https://github.com/flyingtango/DiGCN/blob/main/code/get_adj.py import os.path as osp import numpy as np import scipy.sparse as sp @@ -96,7 +97,7 @@ def get_pr_directed_adj(alpha, edge_index, num_nodes, dtype, edge_weight = None) return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] def get_appr_directed_adj(alpha, edge_index, num_nodes, dtype, edge_weight=None): - if edge_weight ==None: + if edge_weight is None: edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device) fill_value = 1 @@ -154,13 +155,13 @@ def get_appr_directed_adj(alpha, edge_index, num_nodes, dtype, edge_weight=None) deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) deg_inv_sqrt = deg.pow(-0.5) deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 - + return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] -def get_second_directed_adj(edge_index, num_nodes, dtype): - - edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, - device=edge_index.device) +def get_second_directed_adj(edge_index, num_nodes, dtype, edge_weight=None): + if edge_weight is None: + edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, + device=edge_index.device) fill_value = 1 edge_index, edge_weight = add_self_loops( edge_index, edge_weight, fill_value, num_nodes)