定义基础参数

This commit is contained in:
leo
2024-03-06 23:55:49 +08:00
parent b83d032474
commit 230eafb5e3

26
main.py
View File

@@ -1,4 +1,6 @@
# streamlit_app.py # streamlit_app.py
import os
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
from loguru import logger from loguru import logger
@@ -6,6 +8,12 @@ from loguru import logger
from bak.init_data import BatchDataRead from bak.init_data import BatchDataRead
from db_utils import BatchData, session from db_utils import BatchData, session
BASE_CKPT_DIR = "./bak"
RUN_MODE = {
1: 'lbs模式',
2: 'lbs优先模式',
3: '不使用lbs'
}
PAGE_TITLE = "training data configer" PAGE_TITLE = "training data configer"
st.set_page_config( st.set_page_config(
page_title=PAGE_TITLE, page_title=PAGE_TITLE,
@@ -22,9 +30,12 @@ def get_data_from_db():
st.session_state.setdefault('data_table', []) st.session_state.setdefault('data_table', [])
st.session_state.setdefault('username', '') st.session_state.setdefault('username', '')
st.session_state.setdefault('evolve_r', 0.1) st.session_state.setdefault('evolve_r', 0.2)
st.session_state.setdefault('n_epoch', 2)
st.session_state.setdefault('mode', 2)
st.session_state.setdefault('configs', { st.session_state.setdefault('configs', {
'evolve_r': st.session_state.evolve_r 'evolve_r': st.session_state.evolve_r,
'mode': st.session_state.get('mode')
}) })
if not st.session_state.data_table: if not st.session_state.data_table:
st.session_state.data_table = get_data_from_db() st.session_state.data_table = get_data_from_db()
@@ -53,6 +64,7 @@ def update_handler():
def update_config(*args, **kwargs): def update_config(*args, **kwargs):
st.session_state.configs['evolve_r'] = st.session_state.evolve_r st.session_state.configs['evolve_r'] = st.session_state.evolve_r
st.session_state.configs['mode'] = st.session_state.mode
with data_frame_container: with data_frame_container:
@@ -80,5 +92,15 @@ with st.sidebar:
st.slider(label='evolve_r', key='evolve_r', st.slider(label='evolve_r', key='evolve_r',
min_value=0.0, max_value=0.5, step=0.01, min_value=0.0, max_value=0.5, step=0.01,
on_change=update_config) on_change=update_config)
st.selectbox(label='n_trail', key='n_trail',
options=(i for i in range(10, 51, 10)))
st.selectbox(label='n_epoch', key='n_epoch',
options=(i for i in range(1, 6)))
st.selectbox(label='ckpt_path', key='ckpt_path',
options=(os.listdir(BASE_CKPT_DIR)))
st.selectbox(label='mode', key='mode',
format_func=lambda x: RUN_MODE[x],
options=RUN_MODE,
on_change=update_config)
st.divider() st.divider()
st.button("启动", on_click=train) st.button("启动", on_click=train)