定义基础参数
This commit is contained in:
26
main.py
26
main.py
@@ -1,4 +1,6 @@
|
|||||||
# streamlit_app.py
|
# streamlit_app.py
|
||||||
|
import os
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -6,6 +8,12 @@ from loguru import logger
|
|||||||
from bak.init_data import BatchDataRead
|
from bak.init_data import BatchDataRead
|
||||||
from db_utils import BatchData, session
|
from db_utils import BatchData, session
|
||||||
|
|
||||||
|
BASE_CKPT_DIR = "./bak"
|
||||||
|
RUN_MODE = {
|
||||||
|
1: 'lbs模式',
|
||||||
|
2: 'lbs优先模式',
|
||||||
|
3: '不使用lbs'
|
||||||
|
}
|
||||||
PAGE_TITLE = "training data configer"
|
PAGE_TITLE = "training data configer"
|
||||||
st.set_page_config(
|
st.set_page_config(
|
||||||
page_title=PAGE_TITLE,
|
page_title=PAGE_TITLE,
|
||||||
@@ -22,9 +30,12 @@ def get_data_from_db():
|
|||||||
|
|
||||||
st.session_state.setdefault('data_table', [])
|
st.session_state.setdefault('data_table', [])
|
||||||
st.session_state.setdefault('username', '')
|
st.session_state.setdefault('username', '')
|
||||||
st.session_state.setdefault('evolve_r', 0.1)
|
st.session_state.setdefault('evolve_r', 0.2)
|
||||||
|
st.session_state.setdefault('n_epoch', 2)
|
||||||
|
st.session_state.setdefault('mode', 2)
|
||||||
st.session_state.setdefault('configs', {
|
st.session_state.setdefault('configs', {
|
||||||
'evolve_r': st.session_state.evolve_r
|
'evolve_r': st.session_state.evolve_r,
|
||||||
|
'mode': st.session_state.get('mode')
|
||||||
})
|
})
|
||||||
if not st.session_state.data_table:
|
if not st.session_state.data_table:
|
||||||
st.session_state.data_table = get_data_from_db()
|
st.session_state.data_table = get_data_from_db()
|
||||||
@@ -53,6 +64,7 @@ def update_handler():
|
|||||||
|
|
||||||
def update_config(*args, **kwargs):
|
def update_config(*args, **kwargs):
|
||||||
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
||||||
|
st.session_state.configs['mode'] = st.session_state.mode
|
||||||
|
|
||||||
|
|
||||||
with data_frame_container:
|
with data_frame_container:
|
||||||
@@ -80,5 +92,15 @@ with st.sidebar:
|
|||||||
st.slider(label='evolve_r', key='evolve_r',
|
st.slider(label='evolve_r', key='evolve_r',
|
||||||
min_value=0.0, max_value=0.5, step=0.01,
|
min_value=0.0, max_value=0.5, step=0.01,
|
||||||
on_change=update_config)
|
on_change=update_config)
|
||||||
|
st.selectbox(label='n_trail', key='n_trail',
|
||||||
|
options=(i for i in range(10, 51, 10)))
|
||||||
|
st.selectbox(label='n_epoch', key='n_epoch',
|
||||||
|
options=(i for i in range(1, 6)))
|
||||||
|
st.selectbox(label='ckpt_path', key='ckpt_path',
|
||||||
|
options=(os.listdir(BASE_CKPT_DIR)))
|
||||||
|
st.selectbox(label='mode', key='mode',
|
||||||
|
format_func=lambda x: RUN_MODE[x],
|
||||||
|
options=RUN_MODE,
|
||||||
|
on_change=update_config)
|
||||||
st.divider()
|
st.divider()
|
||||||
st.button("启动", on_click=train)
|
st.button("启动", on_click=train)
|
||||||
|
|||||||
Reference in New Issue
Block a user