#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @File     : data_process
# @Author   : bruxellse_li
# @Time     : 2023/3/31 08:39

import os
import pandas as pd
import sys

from pathlib import Path
from pandas import DataFrame
from sklearn.model_selection import train_test_split
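# Add the directory containing the `classification` package to sys.path (resolved relative to the current working directory)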
sys.path.append('../../')
from classification.utils.utils import *


def process_txt(data_loader: DataFrame, train_file_path: str, valid_file_path: str, stop_words_path: str,
                train_size: float = 0.8, random_state: int = 2021):
    """Write fastText-style train/valid corpus files from an article DataFrame.

    For every row whose ``article`` is a string, newlines, carriage returns
    and tabs are removed, the text is segmented with ``seg`` using the stop
    words loaded from ``stop_words_path``, and the line is prefixed with
    ``__label__<label>``. The lines are shuffled, split into train/valid sets
    and written one per line (UTF-8) to the two output files.

    Args:
        data_loader: DataFrame with ``article`` and ``label`` columns.
        train_file_path: output path for the training split.
        valid_file_path: output path for the validation split.
        stop_words_path: path handed to ``stop_words`` to load the stop words.
        train_size: fraction of lines assigned to the training split.
        random_state: seed for the shuffled split, so output is reproducible.

    Raises:
        ValueError: if no row has a string ``article`` (nothing to split).
    """
    # The stop-word list does not depend on the article: load it once
    # instead of re-reading it for every row.
    sw = stop_words(path=stop_words_path)

    article_list = []
    for article, label in zip(data_loader['article'], data_loader['label']):
        if not isinstance(article, str):
            # Non-string cells (e.g. NaN) cannot be segmented; report and skip.
            print('{} is not str!'.format(article))
            continue
        text = article.replace('\n', '').replace('\r', '').replace('\t', '')
        text = seg(text=text, sw=sw)
        article_list.append('__label__{} {}'.format(label, text))

    if not article_list:
        raise ValueError('No string articles found; cannot build train/valid corpus.')

    train_data, valid_data = train_test_split(
        article_list, train_size=train_size, random_state=random_state, shuffle=True
    )

    # Make sure the output directories exist before opening the files.
    for file_path in (train_file_path, valid_file_path):
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)

    with open(train_file_path, 'w', encoding='utf-8') as train_file, open(
        valid_file_path, 'w', encoding='utf-8'
    ) as valid_file:
        train_file.writelines(line + '\n' for line in train_data)
        valid_file.writelines(line + '\n' for line in valid_data)


def process(data_loader: DataFrame, train_file_path: str, valid_file_path: str, stop_words_path: str):
    """Build the cleaned ``article`` text and write the train/valid corpus files.

    ``article`` is ``title + '。' + content`` passed through ``clean_tag`` and
    then ``clean_txt``. The column is added to a new frame via ``assign``, so
    the caller's DataFrame is left unmodified. The frame is then handed to
    ``process_txt``.

    Args:
        data_loader: DataFrame with ``title``, ``content`` and ``label``
            columns (title/content are concatenated as strings).
        train_file_path: output path for the training split.
        valid_file_path: output path for the validation split.
        stop_words_path: path to the stop-word file used during segmentation.

    Returns:
        None
    """
    articles = (
        (data_loader['title'] + '。' + data_loader['content'])
        .apply(clean_tag)
        .apply(clean_txt)
    )
    process_txt(
        # assign() returns a new DataFrame instead of mutating the argument.
        data_loader=data_loader.assign(article=articles),
        train_file_path=train_file_path,
        valid_file_path=valid_file_path,
        stop_words_path=stop_words_path
    )
    return None


# Corpus-processing entry point
def pro_data(modelName, dataFolderName, data_df, stop_words_path, save_data_path):
    """Generate the train/valid corpus files for one model/dataset pair.

    ``save_data_path`` is a template with three ``{}`` placeholders, filled
    with ``modelName``, ``dataFolderName`` and the split name (``'train'`` or
    ``'valid'``), e.g. ``'../../datasets/classification/{}/{}/{}.txt'``.
    ``data_df`` is forwarded unchanged to ``process``.
    """
    def _split_path(split_name):
        # Resolve the output file path for a single split.
        return save_data_path.format(modelName, dataFolderName, split_name)

    process(
        data_loader=data_df,
        train_file_path=_split_path('train'),
        valid_file_path=_split_path('valid'),
        stop_words_path=stop_words_path,
    )


if __name__ == '__main__':
    modelName, dataFolderName, data_path = "gzdt_dataset", "gzdt_V1", "../../datasets/Receive_File/测试数据.xlsx"
    save_data_path = r'../../datasets/classification/{}/{}/{}.txt'
    root_path = r'../../word2vec/doc_similarity/'
    stop_words_path = os.path.join(root_path, 'stop_words.txt')
    # pro_data() expects a DataFrame, not a file path: load the spreadsheet
    # first. keep_default_na=False stops pandas from turning strings such as
    # "NA" into NaN, and astype(str) makes every cell a string so the
    # title/content concatenation in process() works.
    data_df = pd.read_excel(data_path, keep_default_na=False).astype(str)
    pro_data(modelName, dataFolderName, data_df, stop_words_path, save_data_path)
