#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:14
# @Author  : 程婷婷
# @FileName: TextcnnClassifyProcess.py
# @Software: PyCharm
import pandas as pd
import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from bert_serving.client import BertClient
import joblib
from sklearn.utils import class_weight
from model.base.views.data.BaseDataProcess import BaseDataProcess
from model.classify.views.textcnn_classify.data.TextcnnClassifyDataLoader import TextcnnClassifyDataLoader


class TextcnnClassifyProcess(BaseDataProcess):
    """Data preparation for the TextCNN classifier.

    Reads the raw data through ``TextcnnClassifyDataLoader``, maps labels to
    integer ids, tokenizes and pads the texts, splits them into train/test
    sets and builds a BERT-based embedding matrix for the vocabulary.

    Attributes set while running:
        vocab: word -> index mapping produced by the Keras ``Tokenizer``
            (set by ``tokenier``).
        label_mapping: original label -> integer id (set by
            ``runner_process``).
        embedding_matrix: array with one row per vocabulary index
            (set by ``runner_process``).
    """

    def __init__(self, config_path):
        """Initialise the base process and the data loader from *config_path*."""
        super().__init__(config_path)
        self.tcdl = TextcnnClassifyDataLoader(config_path)

    def tokenier(self, data, label, maxlen=3500):
        """Tokenize, split and pad the texts, and persist the tokenizer.

        The tokenizer is fitted on the full ``data`` before the split, so the
        vocabulary stored in ``self.vocab`` covers train and test texts.

        Args:
            data: texts to tokenize.
            label: labels matching ``data``.
            maxlen: length every sequence is padded/truncated to
                (defaults to the previously hard-coded 3500).

        Returns:
            Tuple ``(x_train, y_train, x_test, y_test)`` where the ``x`` parts
            are padded index sequences and the ``y`` parts are the ``label``
            columns of the corresponding split.
        """
        tokenizer = Tokenizer()
        # Word indices are assigned by frequency.
        tokenizer.fit_on_texts(data)
        self.vocab = tokenizer.word_index

        # Built column by column on purpose: assigning ``label`` after
        # ``content`` lets pandas align a label Series on the content index.
        df = pd.DataFrame(columns=['content', 'label'])
        df['content'] = data
        df['label'] = label

        train_set, test_set = self.split_dataset(df, use_dev=self.process_config['use_dev'])

        x_train_word_ids = tokenizer.texts_to_sequences(train_set['content'])
        x_test_word_ids = tokenizer.texts_to_sequences(test_set['content'])
        x_train_padded_seqs = pad_sequences(x_train_word_ids, maxlen=maxlen)
        x_test_padded_seqs = pad_sequences(x_test_word_ids, maxlen=maxlen)

        # Persist the fitted tokenizer so the same vocabulary can be reused later.
        joblib.dump(tokenizer, filename=self.embedding_config['tokenizer_path'])
        return x_train_padded_seqs, train_set['label'], x_test_padded_seqs, test_set['label']

    def get_embeddingMatrix(self, vocab, dim=768, port=5558, port_out=5559):
        """Build the embedding matrix for *vocab* from a BERT service.

        Args:
            vocab: mapping word -> row index (indices start at 1 for a Keras
                ``Tokenizer`` vocabulary).
            dim: embedding width returned by the BERT server.
            port: port the BERT server receives requests on.
            port_out: port the BERT server publishes results on.

        Returns:
            ``numpy.ndarray`` of shape ``(len(vocab) + 1, dim)``. Row 0 stays
            all zeros and is used for padding; rows whose word could not be
            encoded to the expected shape also stay zero.
        """
        # One extra leading row (index 0) of zeros is reserved for padding.
        embedding_matrix = np.zeros((len(vocab) + 1, dim))
        if not vocab:
            # Nothing to encode: avoid opening a connection to the server.
            return embedding_matrix

        # The context manager closes the client's sockets when done;
        # previously the client was never closed.
        with BertClient(port=port, port_out=port_out) as bert_client:
            for word, i in vocab.items():
                try:
                    embedding_vector = bert_client.encode(word.split(' '))
                except KeyError:
                    continue
                if embedding_vector.shape == (1, dim):
                    # Collapse the single-row result to a 1-D vector.
                    embedding_matrix[i] = embedding_vector.mean(axis=0)
                else:
                    print(embedding_vector.shape)
                    print('----------类型错误----------')
        return embedding_matrix

    def class_weight(self, y_train):
        """Compute balanced class weights for *y_train*.

        Args:
            y_train: training labels.

        Returns:
            Dict mapping each class value present in ``y_train`` to its
            'balanced' weight. For contiguous integer labels ``0..n-1`` this
            equals the former index-keyed mapping; keying by the actual class
            value keeps the weights correct when a class id is absent from
            the training split.
        """
        classes = np.unique(y_train)
        # Keyword arguments are required by recent scikit-learn releases and
        # are accepted by older ones as well.
        weight = class_weight.compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
        return dict(zip(classes.tolist(), weight))

    def runner_process(self, logger):
        """Run the full preprocessing pipeline.

        Reads the data, maps labels to integer ids, cleans the texts,
        tokenizes/splits/pads them, then builds and stores the embedding
        matrix.

        Args:
            logger: logger used to report the dataset sizes.

        Returns:
            Tuple ``(x_train, y_train, x_test, y_test)`` as returned by
            ``tokenier``.
        """
        df = self.tcdl.read_file()

        # A bare ``set`` has no stable order for strings across interpreter
        # runs, which made the label ids irreproducible. Sort when possible,
        # otherwise fall back to order of first appearance.
        try:
            all_label = sorted(set(df['label']))
        except TypeError:
            all_label = list(dict.fromkeys(df['label']))
        self.label_mapping = {lab: idx for idx, lab in enumerate(all_label)}
        labels = df['label'].map(self.label_mapping)

        # NOTE(review): ``labels`` comes from the unprocessed frame; if
        # ``process`` drops rows, alignment in ``tokenier`` relies on it
        # returning an index-preserving Series - confirm.
        processed_data = self.process(df['content'], min_content=self.process_config['min_content'])

        x_train_padded_seqs, train_label, x_test_padded_seqs, test_label = self.tokenier(
            processed_data, labels)

        logger.info('处理后的数据量为 %d 条' % (len(train_label) + len(test_label)))
        logger.info('训练集的数据量为 %d 条' % len(train_label))
        logger.info('测试集的数据量为 %d 条' % len(test_label))

        self.embedding_matrix = self.get_embeddingMatrix(self.vocab)
        joblib.dump(self.embedding_matrix, filename=self.embedding_config['embedding_path'])
        return x_train_padded_seqs, train_label, x_test_padded_seqs, test_label
