#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:33
# @Author  : 程婷婷
# @FileName: TextcnnClassifyRunner.py
# @Software: PyCharm
import numpy as np
from model.base.views.runner.BaseRunner import BaseRunner
from model.classify.views.textcnn_classify.data.TextcnnClassifyProcess import TextcnnClassifyProcess
from model.classify.views.textcnn_classify.TextcnnClassifyModel import TextcnnClassifyModel
from model.classify.views.textcnn_classify.TextcnnClassifyEvaluator import TextcnnClassifyEvaluator


class TextcnnClassifyRunner(BaseRunner):
    """Run a single train-and-evaluate pass for the TextCNN classifier.

    Three collaborators are created from the same ``config_path``:

    * ``tcp`` -- a ``TextcnnClassifyProcess`` (data preparation),
    * ``tcm`` -- a ``TextcnnClassifyModel`` (model building),
    * ``tce`` -- a ``TextcnnClassifyEvaluator`` (evaluation).
    """

    def __init__(self, config_path):
        """Initialise the base runner and build the three collaborators.

        Args:
            config_path: Forwarded unchanged to ``BaseRunner`` and to the
                constructor of each collaborator.
        """
        super().__init__(config_path)
        self.tcp = TextcnnClassifyProcess(config_path)
        self.tcm = TextcnnClassifyModel(config_path)
        self.tce = TextcnnClassifyEvaluator(config_path)

    def train(self, logger):
        """Prepare the data, build the model, then evaluate it on the test split.

        Args:
            logger: Handed to the data processor and to the evaluator; this
                method does not call it directly.

        Returns:
            The string ``'success'`` once evaluation has finished.
        """
        processor = self.tcp

        # runner_process yields (train inputs, train labels, test inputs, test labels).
        splits = processor.runner_process(logger)
        train_inputs, train_targets, test_inputs, test_targets = splits

        weights = processor.class_weight(train_targets)
        print(weights)

        # Read embedding_matrix / vocab only after runner_process has run,
        # exactly as the keyword arguments were evaluated before.
        build_kwargs = {
            'x_train_padded_seqs': train_inputs,
            'y_train': train_targets,
            'x_test_padded_seqs': test_inputs,
            'y_test': test_targets,
            'embedding_matrix': processor.embedding_matrix,
            'classes_weight': weights,
            'vocab': processor.vocab,
        }
        fitted_model = self.tcm.building_model(**build_kwargs)

        # Predict, for each test sample, the probability of every class.
        class_scores = fitted_model.predict(test_inputs)
        # Pick the label whose probability is highest for each sample.
        predicted_targets = np.argmax(class_scores, axis=1)

        self.tce.evaluate(
            test_targets,
            predicted_targets,
            processor.label_mapping,
            logger,
        )
        return 'success'


# Example usage (config_path and logger must be supplied by the caller):
# if __name__ == '__main__':
#     state = TextcnnClassifyRunner(config_path).train(logger)
#     print(state)
