diff --git a/main.py b/main.py index 62b0088..143216f 100644 --- a/main.py +++ b/main.py @@ -20,10 +20,14 @@ def get_data_from_db(): return [BatchDataRead.from_orm(db_obj).dict() for db_obj in db_objs] -st.session_state.setdefault('data_table', get_data_from_db()) -st.session_state.setdefault('configs', {}) +st.session_state.setdefault('data_table', []) st.session_state.setdefault('username', '') st.session_state.setdefault('evolve_r', 0.1) +st.session_state.setdefault('configs', { + 'evolve_r': st.session_state.evolve_r +}) +if not st.session_state.data_table: + st.session_state.data_table = get_data_from_db() df = pd.DataFrame(data=st.session_state.data_table) @@ -39,13 +43,19 @@ def train(): def update_handler(): edited_rows = st.session_state['edited_info'].get('edited_rows') for id_, update_data in edited_rows.items(): - row_db = session.query(BatchData).where(BatchData.id == int(edited_df.loc[id_].id)).first() - logger.info(f"update: {update_data}") + row_id = int(edited_df.loc[id_].id) + row_db = session.query(BatchData).where(BatchData.id == row_id).first() + logger.info(f"{row_id=}, {update_data=}") for field in update_data: setattr(row_db, field, update_data[field]) session.commit() +def update_config(*args, **kwargs): + logger.debug(f"{args=}, {kwargs=}, {st.session_state.configs=}, {st.session_state.evolve_r=}") + st.session_state.configs['evolve_r'] = st.session_state.evolve_r + + with data_frame_container: edited_df = st.data_editor( df, key="edited_info", @@ -63,14 +73,15 @@ with data_frame_container: }) with config_container: - # st.empty() - st.session_state.configs['evolve_r'] = st.session_state.evolve_r st.json(st.session_state.configs) with st.sidebar: st.divider() - st.session_state.evolve_r = st.slider('evolve_r', min_value=0.0, max_value=0.5, value=st.session_state.evolve_r, step=0.01) + st.slider(label='evolve_r', key='evolve_r', + min_value=0.0, max_value=0.5, step=0.01, + on_change=update_config) + logger.debug(st.session_state.evolve_r) st.text_input("username", key='username') st.divider() st.button("启动", on_click=train)