实现参数拖拽
This commit is contained in:
25
main.py
25
main.py
@@ -20,10 +20,14 @@ def get_data_from_db():
|
|||||||
return [BatchDataRead.from_orm(db_obj).dict() for db_obj in db_objs]
|
return [BatchDataRead.from_orm(db_obj).dict() for db_obj in db_objs]
|
||||||
|
|
||||||
|
|
||||||
st.session_state.setdefault('data_table', get_data_from_db())
|
st.session_state.setdefault('data_table', [])
|
||||||
st.session_state.setdefault('configs', {})
|
|
||||||
st.session_state.setdefault('username', '')
|
st.session_state.setdefault('username', '')
|
||||||
st.session_state.setdefault('evolve_r', 0.1)
|
st.session_state.setdefault('evolve_r', 0.1)
|
||||||
|
st.session_state.setdefault('configs', {
|
||||||
|
'evolve_r': st.session_state.evolve_r
|
||||||
|
})
|
||||||
|
if not st.session_state.data_table:
|
||||||
|
st.session_state.data_table = get_data_from_db()
|
||||||
|
|
||||||
df = pd.DataFrame(data=st.session_state.data_table)
|
df = pd.DataFrame(data=st.session_state.data_table)
|
||||||
|
|
||||||
@@ -39,13 +43,19 @@ def train():
|
|||||||
def update_handler():
|
def update_handler():
|
||||||
edited_rows = st.session_state['edited_info'].get('edited_rows')
|
edited_rows = st.session_state['edited_info'].get('edited_rows')
|
||||||
for id_, update_data in edited_rows.items():
|
for id_, update_data in edited_rows.items():
|
||||||
row_db = session.query(BatchData).where(BatchData.id == int(edited_df.loc[id_].id)).first()
|
row_id = int(edited_df.loc[id_].id)
|
||||||
logger.info(f"update: {update_data}")
|
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
|
||||||
|
logger.info(f"{row_id=}, {update_data=}")
|
||||||
for field in update_data:
|
for field in update_data:
|
||||||
setattr(row_db, field, update_data[field])
|
setattr(row_db, field, update_data[field])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def update_config(*args, **kwargs):
|
||||||
|
logger.debug(f"{args=}, {kwargs=}, {st.session_state.configs=}, {st.session_state.evolve_r=}")
|
||||||
|
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
||||||
|
|
||||||
|
|
||||||
with data_frame_container:
|
with data_frame_container:
|
||||||
edited_df = st.data_editor(
|
edited_df = st.data_editor(
|
||||||
df, key="edited_info",
|
df, key="edited_info",
|
||||||
@@ -63,14 +73,15 @@ with data_frame_container:
|
|||||||
})
|
})
|
||||||
|
|
||||||
with config_container:
|
with config_container:
|
||||||
# st.empty()
|
|
||||||
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
|
||||||
st.json(st.session_state.configs)
|
st.json(st.session_state.configs)
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
st.session_state.evolve_r = st.slider('evolve_r', min_value=0.0, max_value=0.5, value=st.session_state.evolve_r, step=0.01)
|
st.slider(label='evolve_r', key='evolve_r',
|
||||||
|
min_value=0.0, max_value=0.5, step=0.01,
|
||||||
|
on_change=update_config)
|
||||||
|
logger.debug(st.session_state.evolve_r)
|
||||||
st.text_input("username", key='username')
|
st.text_input("username", key='username')
|
||||||
st.divider()
|
st.divider()
|
||||||
st.button("启动", on_click=train)
|
st.button("启动", on_click=train)
|
||||||
|
|||||||
Reference in New Issue
Block a user