From 7128d489da0bcc3e6d3c4dda873adf57b653667c Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 7 Mar 2024 23:15:38 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90v1,=20=E6=95=B4=E7=90=86?= =?UTF-8?q?=E5=A5=BD=E6=89=80=E6=9C=89=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index f9a9876..47702d0 100644 --- a/main.py +++ b/main.py @@ -43,6 +43,7 @@ def make_train_dicts(with_entities: tuple, conditions: tuple): def prepare_train_data(mode_id) -> dict[str, list]: + # 训练集/测试集 实体类定义, 查询条件定义 train_we, train_cd = tuple(), tuple() val_we, val_cd = tuple(), tuple() if mode_id == 1: @@ -98,8 +99,11 @@ left_col, right_col = st.columns([3, 1]) def train(): - right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')}) - right_col.json(prepare_train_data(st.session_state.mode)) + train_config_dict = { + k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') + } + train_config_dict.update(prepare_train_data(st.session_state.mode)) + right_col.json(train_config_dict) def table_update_handler():