#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Time    : 2021/5/11 20:33
# @Author  : 程婷婷
# @FileName: XgboostClassifyRunner.py
# @Software: PyCharm
import os
import numpy as np
import torch
import random
from model.base.views.runner import BaseRunner
from model.classify.views.flair_classify import FlairClassifyProcess
from model.classify.views.flair_classify import FlairClassifyModel
from model.classify.views.flair_classify import FlairClassifyEvaluator

class FlairClassifyRunner(BaseRunner):
    """Runner that wires together the Flair text-classification pipeline.

    It owns three components, each built from the same configuration file:
    a data processor (``fcp``), a model builder (``fcm``) and an evaluator
    (``fce``).
    """

    def __init__(self, config_path):
        """Construct the runner and its pipeline components.

        :param config_path: path to the configuration file; it is forwarded
            unchanged to ``BaseRunner`` and to every Flair component.
        """
        super().__init__(config_path)
        self.fcp = FlairClassifyProcess(config_path)
        self.fcm = FlairClassifyModel(config_path)
        self.fce = FlairClassifyEvaluator(config_path)

    @staticmethod
    def reproducibility(seed):
        """Fix the random seeds of ``random``, NumPy and PyTorch.

        Declared as a ``staticmethod`` because it uses no instance state. The
        original definition lacked ``self``, so calling it on an instance
        (``runner.reproducibility(42)``) raised ``TypeError``. It can now be
        called both as ``FlairClassifyRunner.reproducibility(42)`` and on an
        instance.

        :param seed: integer seed applied to every generator.
        :return: None
        """
        # Side effect: mutates the process environment. Python fixes its own
        # hash seed at interpreter start-up, so this only influences child
        # processes launched afterwards, not the current interpreter.
        os.environ["PYTHONHASHSEED"] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def train(self):
        """Preprocess the data and build the classifier.

        Calls ``runner_process()`` on the processor to obtain the corpus,
        document embeddings, label dictionary and loss weights. It then passes
        them to ``building_model``, which presumably also trains the model;
        confirm this in ``FlairClassifyModel``.

        :return: the string ``'success'`` once ``building_model`` returns.
        """
        corpus, document_embeddings, label_dict, loss_weights = self.fcp.runner_process()
        self.fcm.building_model(
            corpus=corpus,
            document_embeddings=document_embeddings,
            label_dict=label_dict,
            loss_weights=loss_weights,
        )
        # TODO: evaluation is currently disabled. To enable it, obtain the
        # true and predicted labels and pass them to self.fce.evaluate(...).
        return 'success'
if __name__ == '__main__':
    import argparse

    # FlairClassifyRunner.__init__ requires a config path. The previous
    # ``FlairClassifyRunner()`` call always failed with TypeError, so the
    # path is now taken from the command line.
    parser = argparse.ArgumentParser(description='Train the Flair text classifier.')
    parser.add_argument(
        'config_path',
        help='path to the configuration file passed to FlairClassifyRunner',
    )
    args = parser.parse_args()
    state = FlairClassifyRunner(args.config_path).train()
    print(state)