# -*- coding:utf-8 -*-

"""



Author:

    Weichen Shen,weichenswc@163.com



Reference:

    [1] Perozzi B, Al-Rfou R, Skiena S. Deepwalk: Online learning of social representations[C]//Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining. ACM, 2014: 701-710.(http://www.perozzi.net/publications/14_kdd_deepwalk.pdf)



"""
from gensim.models import Word2Vec
import sys
sys.path.append('../ge')
from walker import RandomWalker


class DeepWalk:
    """DeepWalk node embeddings: random walks over a graph fed to Word2Vec.

    The walks are generated once, at construction time, and stored in
    ``self.sentences``. Call :meth:`train` to fit the Word2Vec model and
    :meth:`get_embeddings` to read the per-node vectors back.
    """

    def __init__(self, graph, walk_length, num_walks, workers=1):
        """
        Initialise the model and generate the random-walk corpus.

        :param graph: input graph (documented upstream as ``nx.Graph``); it
            must expose ``nodes()``, which :meth:`get_embeddings` iterates
        :param walk_length: int, length of each random walk
        :param num_walks: int, number of random walks per node
        :param workers: int, number of parallel workers handed to the walker
        """
        self.graph = graph
        self.w2v_model = None  # set by train()
        self._embeddings = {}  # rebuilt on every get_embeddings() call

        # NOTE(review): p=1, q=1 presumably selects an unbiased walk in
        # RandomWalker -- confirm against the walker implementation.
        self.walker = RandomWalker(
            graph, p=1, q=1, )
        self.sentences = self.walker.simulate_walks(
            num_walks=num_walks, walk_length=walk_length, workers=workers, verbose=1)

    def train(self, embed_size=128, window_size=5, workers=3, iter=5, **kwargs):
        """
        Fit a Word2Vec model on the stored random walks.

        ``sg`` and ``hs`` are always forced to 1 (skip-gram with hierarchical
        softmax), overriding any value supplied through ``kwargs``.
        ``min_count`` defaults to 0 but a caller-supplied value is respected.

        :param embed_size: int, dimensionality of the embedding vectors
        :param window_size: int, Word2Vec context window size
        :param workers: int, number of Word2Vec worker threads
        :param iter: int, number of training epochs
        :param kwargs: dict, additional Word2Vec keyword arguments
        :return: the trained Word2Vec model (also stored on ``self.w2v_model``)
        """
        # Keep every node in the vocabulary unless the caller says otherwise.
        kwargs.setdefault("min_count", 0)
        kwargs.update(
            sentences=self.sentences,
            vector_size=embed_size,
            sg=1,  # skip gram
            hs=1,  # deepwalk use Hierarchical Softmax
            workers=workers,
            window=window_size,
            epochs=iter,
        )

        print("Learning embedding vectors...")
        model = Word2Vec(**kwargs)
        print("Learning embedding vectors done!")

        self.w2v_model = model
        return model

    def get_embeddings(self, ):
        """
        Collect the learned embedding of every graph node.

        Nodes that are absent from the trained vocabulary (for example
        because a ``min_count`` greater than 0 was passed to :meth:`train`)
        are skipped instead of aborting the whole lookup; the number of
        skipped nodes is printed.

        :return: dict mapping node -> embedding vector; an empty dict when
            the model has not been trained yet
        """
        if self.w2v_model is None:
            print("model not train")
            return {}

        vectors = self.w2v_model.wv
        self._embeddings = {}
        skipped = 0
        for node in self.graph.nodes():
            # Check membership first: indexing a key that is not in the
            # vocabulary does not yield a valid vector for that node.
            if node in vectors:
                self._embeddings[node] = vectors[node]
            else:
                skipped += 1

        if skipped:
            print("{} node(s) have no learned embedding and were skipped".format(skipped))

        return self._embeddings
