"""GraphWave class implementation."""

import pygsp
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import networkx as nx
from pydoc import locate

class WaveletMachine:
    """
    "Learning Structural Node Embeddings Via Diffusion Wavelets" 的一种实现。
    这个类封装了通过扩散小波学习节点结构嵌入的核心逻辑。
    """
    def __init__(self, G, settings):
        """
        初始化方法。
        :param G: 输入的 networkx 图对象。
        :param settings: 包含设置的 argparse 对象。
        """
        # 记录图中的节点，作为嵌入结果的索引
        self.index = G.nodes()
        # 将 networkx 图转换为 pygsp 图对象，便于进行谱图分析
        self.G = pygsp.graphs.Graph(nx.adjacency_matrix(G))
        self.number_of_nodes = len(nx.nodes(G))
        self.settings = settings
        # 如果节点数超过阈值，则自动切换到近似计算方法以提高效率
        if self.number_of_nodes > self.settings.switch:
            self.settings.mechanism = "approximate"

        # 根据采样点数和步长，生成一系列的时间/尺度参数
        self.steps = [x*self.settings.step_size for x in range(self.settings.sample_number)]

    def single_wavelet_generator(self, node):
        """
        使用特征分解，为给定节点计算特征函数。
        :param node: 正在被嵌入的节点。
        :return: 小波系数。
        """
        # 创建一个脉冲信号，只在当前节点处为1
        impulse = np.zeros((self.number_of_nodes))
        impulse[node] = 1.0
        # 计算热核，这是扩散小波的核心
        diags = np.diag(np.exp(-self.settings.heat_coefficient*self.eigen_values))
        eigen_diag = np.dot(self.eigen_vectors, diags)
        waves = np.dot(eigen_diag, np.transpose(self.eigen_vectors))
        # 计算小波系数
        wavelet_coefficients = np.dot(waves, impulse)
        return wavelet_coefficients

    def exact_wavelet_calculator(self):
        """
        使用精确的特征值分解来计算结构角色嵌入。
        """
        self.real_and_imaginary = []
        # 遍历图中每个节点
        for node in tqdm(range(self.number_of_nodes)):
            # 为当前节点生成小波信号
            wave = self.single_wavelet_generator(node)
            # 通过计算特征函数的期望值来获得小波系数（傅里叶变换）
            wavelet_coefficients = [np.mean(np.exp(wave*1.0*step*1j)) for step in self.steps]
            self.real_and_imaginary.append(wavelet_coefficients)
        # 将结果转换为 numpy 数组
        self.real_and_imaginary = np.array(self.real_and_imaginary)

    def exact_structural_wavelet_embedding(self):
        """
        计算特征向量、特征值，并创建精确的嵌入。
        """
        # 计算图的傅里叶基（特征向量和特征值）
        self.G.compute_fourier_basis()
        # 对特征值进行归一化
        self.eigen_values = self.G.e / max(self.G.e)
        self.eigen_vectors = self.G.U
        # 计算精确的小波嵌入
        self.exact_wavelet_calculator()

    def approximate_wavelet_calculator(self):
        """
        给定切比雪夫多项式，计算近似嵌入。
        """
        self.real_and_imaginary = []
        # 遍历图中每个节点
        for node in tqdm(range(self.number_of_nodes)):
            # 创建脉冲信号
            impulse = np.zeros((self.number_of_nodes))
            impulse[node] = 1
            # 使用切比雪夫多项式近似计算小波变换
            wave_coeffs = pygsp.filters.approximations.cheby_op(self.G, self.chebyshev, impulse)
            # 计算特征函数的期望值
            real_imag = [np.mean(np.exp(wave_coeffs*1*step*1j)) for step in self.steps]
            self.real_and_imaginary.append(real_imag)
        self.real_and_imaginary = np.array(self.real_and_imaginary)

    def approximate_structural_wavelet_embedding(self):
        """
        估计最大特征值。
        设置热滤波器和切比雪夫多项式。
        使用近似小波计算器方法。
        """
        # 估计图拉普拉斯算子的最大特征值
        self.G.estimate_lmax()
        # 定义热核滤波器
        self.heat_filter = pygsp.filters.Heat(self.G, tau=[self.settings.heat_coefficient])
        # 计算用于近似滤波器的切比雪夫系数
        self.chebyshev = pygsp.filters.approximations.compute_cheby_coeff(self.heat_filter,
                                                                          m=self.settings.approximation)
        # 计算近似的小波嵌入
        self.approximate_wavelet_calculator()

    def create_embedding(self):
        """
        根据 mechanism 的设置，创建精确或近似的嵌入。
        """
        if self.settings.mechanism == "exact":
            self.exact_structural_wavelet_embedding()
        else:
            self.approximate_structural_wavelet_embedding()

    def transform_and_save_embedding(self):
        """
        转换包含实部和虚部值的 numpy 数组。
        创建一个 pandas 数据帧并将其保存为 csv。
        """
        print("\nSaving the embedding.")
        # 将复数嵌入分解为实部和虚部
        features = [self.real_and_imaginary.real, self.real_and_imaginary.imag]
        self.real_and_imaginary = np.concatenate(features, axis=1)
        # 创建列名
        columns_1 = ["reals_"+str(x) for x in range(self.settings.sample_number)]
        columns_2 = ["imags_"+str(x) for x in range(self.settings.sample_number)]
        columns = columns_1 + columns_2
        # 创建 DataFrame
        self.real_and_imaginary = pd.DataFrame(self.real_and_imaginary, columns=columns)
        self.real_and_imaginary.index = self.index
        # 根据指定的节点标签类型对索引进行排序
        self.real_and_imaginary.index = self.real_and_imaginary.index.astype(locate(self.settings.node_label_type))
        self.real_and_imaginary = self.real_and_imaginary.sort_index()
        # 保存到 CSV 文件
        self.real_and_imaginary.to_csv(self.settings.output)
