#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:14
# @Author  : 程婷婷
# @FileName: XgboostClassifyProcess.py
# @Software: PyCharm
import numbers
import re
import time

from model.base import BaseDataProcess
from model.classify import FastTextDataLoader

class FastTextProcess(BaseDataProcess.BaseDataProcess):
    """Turn a labelled text corpus into fastText-format train/test files.

    Data is read through ``FastTextDataLoader``. Each text is cleaned and
    tokenized, labels are integer-encoded when they are not already integers,
    and ``__label__<label> <text>`` lines are written to the train/test paths
    named in ``process_config``.
    """

    # Characters stripped before tokenization: code points above the BMP
    # (U+10000..U+10FFFF), lone UTF-16 surrogates, lowercase ASCII letters and
    # any whitespace. Raw string so ``\s`` is a regex escape rather than an
    # invalid Python escape; compiled once instead of on every call.
    # NOTE(review): only a-z is removed, A-Z is kept — confirm this is intended.
    _GRAPH_FILTER = re.compile(
        r'[\U00010000-\U0010ffff\uD800-\uDBFF\uDC00-\uDFFFa-z\n\s]')

    def __init__(self, config_path):
        super().__init__(config_path)
        self.ftdl = FastTextDataLoader.FastTextDataLoader(config_path)

    def remove_char(self, content):
        """Return *content* with every character matched by ``_GRAPH_FILTER`` removed."""
        return self._GRAPH_FILTER.sub('', content)

    def _process_with_index(self, data, min_content):
        """Clean and tokenize *data*, also reporting which positions were kept.

        Args:
            data: Sized iterable of raw texts.
            min_content: Texts whose cleaned length is not greater than this
                are dropped.

        Returns:
            ``(processed, kept)`` where ``processed[j]`` is the tokenized form
            of the record at position ``kept[j]`` of *data*.
        """
        # Choose the tokenizer once instead of re-reading config per record.
        if self.process_config['tokenizer'] == 'PerceptronLexicalAnalyzer':
            tokenize = self.pla_tokenizer
        else:
            tokenize = self.jieba_tokenizer
        total = len(data)
        processed, kept = [], []
        for pos, record in enumerate(data):
            # Non-string cells (e.g. NaN from an empty field) cannot be passed
            # to re.sub; treat them as too short rather than crashing.
            text = self.remove_char(record) if isinstance(record, str) else ''
            if len(text) > min_content:
                processed.append(tokenize(text))
                kept.append(pos)
            done = pos + 1  # number of records handled so far
            if done % 100 == 0 or done == total:
                print(time.strftime('%Y-%m-%d %H:%M:%S'), '第', done, '条文本分词完毕')
        return processed, kept

    def process(self, data, min_content):
        """Clean and tokenize *data*, dropping texts of length <= *min_content*.

        Returns:
            List of tokenized records, in input order, for the kept texts only.
        """
        processed, _ = self._process_with_index(data, min_content)
        return processed

    def transform_data(self, data, labels):
        """Format paired texts and labels as fastText lines ``__label__<label> <text>``.

        Args:
            data: Sequence of processed texts.
            labels: Sequence of labels, positionally aligned with *data*.

        Raises:
            ValueError: If *data* and *labels* differ in length, which would
                otherwise silently pair texts with the wrong labels.
        """
        if len(data) != len(labels):
            raise ValueError(
                'data and labels length mismatch: %d != %d' % (len(data), len(labels)))
        # zip iterates positionally, so a pandas Series index is irrelevant.
        return ["__label__{} {}".format(label, text) for text, label in zip(data, labels)]

    def _encode_labels(self, raw_labels):
        """Return labels as a positional list, integer-encoding non-int labels.

        If the first label is an integer (Python or numpy), the labels are
        returned unchanged. Otherwise ``self.label_mapping`` is set to a
        label -> index dict built over the sorted distinct labels, so the
        encoding is the same on every run.
        """
        values = list(raw_labels)
        first = values[0] if values else None
        if isinstance(first, numbers.Integral) and not isinstance(first, bool):
            return values
        all_label = sorted(set(values), key=str)
        self.label_mapping = {label: idx for idx, label in enumerate(all_label)}
        return [self.label_mapping[value] for value in values]

    def runner_process(self, logger):
        """Build and write the fastText train/test files.

        Returns:
            ``(train_file_path, test_file_path)`` from ``process_config``.
        """
        df = self.ftdl.read_file()
        processed_data, kept = self._process_with_index(df['content'], min_content=10)
        encoded = self._encode_labels(df['label'])
        # Short texts were dropped above; keep only the labels of surviving
        # rows so every text stays paired with its own label.
        labels = [encoded[pos] for pos in kept]
        logger.debug('labels: %s', labels)
        format_data = self.transform_data(processed_data, labels)
        use_dev = self.process_config['use_dev']
        if use_dev:
            train_data_set, test_data_set, dev_data_set = self.split_dataset(format_data, use_dev=use_dev)
        else:
            train_data_set, test_data_set = self.split_dataset(format_data, use_dev=use_dev)
        with open(self.process_config['train_file_path'], 'w', encoding='utf-8') as trainf, \
                open(self.process_config['test_file_path'], 'w', encoding='utf-8') as testf:
            trainf.writelines(row + '\n' for row in train_data_set)
            testf.writelines(row + '\n' for row in test_data_set)
        logger.info('处理后的数据量为 %d 条', len(format_data))
        logger.info('训练集的数据量为 %d 条', len(train_data_set))
        logger.info('测试集的数据量为 %d 条', len(test_data_set))
        if use_dev:
            # NOTE(review): the dev split is computed but not written to disk;
            # add a dev output path to the config if it is needed.
            logger.info('验证集的数据量为 %d 条', len(dev_data_set))
        return self.process_config['train_file_path'], self.process_config['test_file_path']