From 8c78592f0a3c8d0014d565f7826cc5b42e4613d8 Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 7 Mar 2024 20:09:12 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=82=E6=95=B0=E6=98=A0=E5=B0=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data.db | Bin 28672 -> 28672 bytes main.py | 20 ++++++++++++++------ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/data.db b/data.db index 97a46087c7f17ebbb6a99f4e550c523530f52599..6a87e8554263d4743b2b25904bb786465eb4b1a2 100644 GIT binary patch delta 96 zcmZp8z}WDBae_2s+(a2?!8iuJv{$?g3=E8XD;W4z@a1kS{KLCBhi?TV6I1qPF#&HT w8SXa>TntR2!h)QQKJ5Gu7H4B{qc#Tz!{oox`!-t(Z{%PB8gXv(2Wf#20QSQfZ2$lO delta 105 zcmV-v0G9uN-~oW(0gxL3U6C9^5nTW-ZnzBq0003Fr~nSA4sEe<{0*~g4yXYF3+%HN z5K96X0002l00I>g5eW!^V}UUU2nXj&NJLOEGLx}B9Fx!*1e0JdjFVs(6|*}Nu?Pf# LO9u|K@EZ_RM)V$? diff --git a/main.py b/main.py index b4e1434..a82254c 100644 --- a/main.py +++ b/main.py @@ -22,6 +22,10 @@ st.set_page_config( initial_sidebar_state="expanded", ) +@st.cache_data +def list_ckpt_paths(dir_path): + return os.listdir(dir_path) + def get_data_from_db(): logger.debug("init") @@ -32,11 +36,12 @@ def get_data_from_db(): st.session_state.setdefault('data_table', []) st.session_state.setdefault('username', '') st.session_state.setdefault('evolve_r', 0.2) +st.session_state.setdefault('n_trail', 10) st.session_state.setdefault('n_epoch', 2) +st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0]) st.session_state.setdefault('mode', 2) st.session_state.setdefault('configs', { - 'evolve_r': st.session_state.evolve_r, - 'mode': st.session_state.get('mode') + k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') }) if not st.session_state.data_table: st.session_state.data_table = get_data_from_db() @@ -65,8 +70,9 @@ def update_handler(): def update_config(*args, **kwargs): - st.session_state.configs['evolve_r'] = st.session_state.evolve_r - st.session_state.configs['mode'] = st.session_state.mode + return None + # for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode'): + # st.session_state.configs[k] = st.session_state[k] with left_col: @@ -97,12 +103,14 @@ with right_col: st.selectbox(label='n_epoch', key='n_epoch', options=(i for i in range(1, 6))) st.selectbox(label='ckpt_path', key='ckpt_path', - options=(os.listdir(BASE_CKPT_DIR))) + options=(list_ckpt_paths(BASE_CKPT_DIR))) st.selectbox(label='mode', key='mode', format_func=lambda x: RUN_MODE[x], options=RUN_MODE, on_change=update_config) - st.json(st.session_state.configs) + st.json({ + k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') + }) st.divider() st.button("启动", on_click=train)