From 230eafb5e3a7893bd1b7493d63ea970bc4cf1d3b Mon Sep 17 00:00:00 2001 From: leo Date: Wed, 6 Mar 2024 23:55:49 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=9F=BA=E7=A1=80=E5=8F=82?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index e54ee7a..6fd3e5b 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,6 @@ # streamlit_app.py +import os + import pandas as pd import streamlit as st from loguru import logger @@ -6,6 +8,12 @@ from loguru import logger from bak.init_data import BatchDataRead from db_utils import BatchData, session +BASE_CKPT_DIR = "./bak" +RUN_MODE = { + 1: 'lbs模式', + 2: 'lbs优先模式', + 3: '不使用lbs' +} PAGE_TITLE = "training data configer" st.set_page_config( page_title=PAGE_TITLE, @@ -22,9 +30,12 @@ def get_data_from_db(): st.session_state.setdefault('data_table', []) st.session_state.setdefault('username', '') -st.session_state.setdefault('evolve_r', 0.1) +st.session_state.setdefault('evolve_r', 0.2) +st.session_state.setdefault('n_epoch', 2) +st.session_state.setdefault('mode', 2) st.session_state.setdefault('configs', { - 'evolve_r': st.session_state.evolve_r + 'evolve_r': st.session_state.evolve_r, + 'mode': st.session_state.get('mode') }) if not st.session_state.data_table: st.session_state.data_table = get_data_from_db() @@ -53,6 +64,7 @@ def update_handler(): def update_config(*args, **kwargs): st.session_state.configs['evolve_r'] = st.session_state.evolve_r + st.session_state.configs['mode'] = st.session_state.mode with data_frame_container: @@ -80,5 +92,15 @@ with st.sidebar: st.slider(label='evolve_r', key='evolve_r', min_value=0.0, max_value=0.5, step=0.01, on_change=update_config) + st.selectbox(label='n_trail', key='n_trail', + options=(i for i in range(10, 51, 10))) + st.selectbox(label='n_epoch', key='n_epoch', + options=(i for i in range(1, 6))) + st.selectbox(label='ckpt_path', key='ckpt_path', + options=(os.listdir(BASE_CKPT_DIR))) + st.selectbox(label='mode', key='mode', + format_func=lambda x: RUN_MODE[x], + options=RUN_MODE, + on_change=update_config) st.divider() st.button("启动", on_click=train)