支持增加 删除行
This commit is contained in:
31
main.py
31
main.py
@@ -5,7 +5,7 @@ import pandas as pd
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
from bak.init_data import BatchDataRead
|
||||
from bak.init_data import BatchDataCreate, BatchDataRead
|
||||
from db_utils import BatchData, session
|
||||
|
||||
BASE_CKPT_DIR = "./bak"
|
||||
@@ -56,22 +56,44 @@ def train():
|
||||
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
||||
|
||||
|
||||
def update_handler():
|
||||
def table_update_handler():
|
||||
logger.debug(f"{st.session_state.edited_info=}")
|
||||
edited_rows = st.session_state.edited_info.get('edited_rows')
|
||||
added_rows = st.session_state.edited_info.get('added_rows')
|
||||
deleted_rows = st.session_state.edited_info.get('deleted_rows')
|
||||
# 更改列
|
||||
for id_, update_data in edited_rows.items():
|
||||
row_id = int(edited_df.loc[id_].id)
|
||||
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
|
||||
logger.info(f"{row_id=}, {update_data=}")
|
||||
logger.debug(BatchDataRead.from_orm(row_db))
|
||||
for field in update_data:
|
||||
setattr(row_db, field, update_data[field])
|
||||
session.commit()
|
||||
|
||||
# 增加列
|
||||
for new_row_data in added_rows:
|
||||
logger.debug(f"{new_row_data=}")
|
||||
|
||||
for deleted_row in deleted_rows:
|
||||
row_id = int(edited_df.loc[deleted_row].id)
|
||||
logger.debug(f"{row_id=}")
|
||||
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
|
||||
session.delete(row_db)
|
||||
session.commit()
|
||||
|
||||
|
||||
def update_config(*args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def create_new_row():
|
||||
session.add(
|
||||
BatchData(**BatchDataCreate().dict())
|
||||
)
|
||||
session.commit()
|
||||
st.session_state.data_table = get_data_from_db()
|
||||
|
||||
|
||||
with left_col:
|
||||
edited_df = st.data_editor(
|
||||
df, key="edited_info",
|
||||
@@ -79,7 +101,7 @@ with left_col:
|
||||
height=600,
|
||||
hide_index=True,
|
||||
use_container_width=True,
|
||||
on_change=update_handler,
|
||||
on_change=table_update_handler,
|
||||
column_order=(
|
||||
'id', 'year', 'census_batch', 'id_code', 'precision', 'is_train', 'is_validation', 'ann_file', 'ann_file_lbs', 'img_prefix',),
|
||||
column_config={
|
||||
@@ -93,6 +115,7 @@ with left_col:
|
||||
'ann_file_lbs': st.column_config.Column("lbs_path", width='medium'),
|
||||
'img_prefix': st.column_config.Column("img_prefix", width='medium'),
|
||||
})
|
||||
st.button("新建一行", on_click=create_new_row)
|
||||
|
||||
with right_col:
|
||||
st.slider(label='evolve_r', key='evolve_r',
|
||||
|
||||
Reference in New Issue
Block a user