参数映射
This commit is contained in:
20
main.py
20
main.py
@@ -22,6 +22,10 @@ st.set_page_config(
|
|||||||
initial_sidebar_state="expanded",
|
initial_sidebar_state="expanded",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@st.cache_data
|
||||||
|
def list_ckpt_paths(dir_path):
|
||||||
|
return os.listdir(dir_path)
|
||||||
|
|
||||||
|
|
||||||
def get_data_from_db():
|
def get_data_from_db():
|
||||||
logger.debug("init")
|
logger.debug("init")
|
||||||
@@ -32,11 +36,12 @@ def get_data_from_db():
|
|||||||
st.session_state.setdefault('data_table', [])
|
st.session_state.setdefault('data_table', [])
|
||||||
st.session_state.setdefault('username', '')
|
st.session_state.setdefault('username', '')
|
||||||
st.session_state.setdefault('evolve_r', 0.2)
|
st.session_state.setdefault('evolve_r', 0.2)
|
||||||
|
st.session_state.setdefault('n_trail', 10)
|
||||||
st.session_state.setdefault('n_epoch', 2)
|
st.session_state.setdefault('n_epoch', 2)
|
||||||
|
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
||||||
st.session_state.setdefault('mode', 2)
|
st.session_state.setdefault('mode', 2)
|
||||||
st.session_state.setdefault('configs', {
|
st.session_state.setdefault('configs', {
|
||||||
'evolve_r': st.session_state.evolve_r,
|
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
||||||
'mode': st.session_state.get('mode')
|
|
||||||
})
|
})
|
||||||
if not st.session_state.data_table:
|
if not st.session_state.data_table:
|
||||||
st.session_state.data_table = get_data_from_db()
|
st.session_state.data_table = get_data_from_db()
|
||||||
@@ -65,8 +70,9 @@ def update_handler():
|
|||||||
|
|
||||||
|
|
||||||
def update_config(*args, **kwargs):
|
def update_config(*args, **kwargs):
|
||||||
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
return None
|
||||||
st.session_state.configs['mode'] = st.session_state.mode
|
# for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode'):
|
||||||
|
# st.session_state.configs[k] = st.session_state[k]
|
||||||
|
|
||||||
|
|
||||||
with left_col:
|
with left_col:
|
||||||
@@ -97,12 +103,14 @@ with right_col:
|
|||||||
st.selectbox(label='n_epoch', key='n_epoch',
|
st.selectbox(label='n_epoch', key='n_epoch',
|
||||||
options=(i for i in range(1, 6)))
|
options=(i for i in range(1, 6)))
|
||||||
st.selectbox(label='ckpt_path', key='ckpt_path',
|
st.selectbox(label='ckpt_path', key='ckpt_path',
|
||||||
options=(os.listdir(BASE_CKPT_DIR)))
|
options=(list_ckpt_paths(BASE_CKPT_DIR)))
|
||||||
st.selectbox(label='mode', key='mode',
|
st.selectbox(label='mode', key='mode',
|
||||||
format_func=lambda x: RUN_MODE[x],
|
format_func=lambda x: RUN_MODE[x],
|
||||||
options=RUN_MODE,
|
options=RUN_MODE,
|
||||||
on_change=update_config)
|
on_change=update_config)
|
||||||
st.json(st.session_state.configs)
|
st.json({
|
||||||
|
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
||||||
|
})
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
st.button("启动", on_click=train)
|
st.button("启动", on_click=train)
|
||||||
|
|||||||
Reference in New Issue
Block a user