支持增加 删除行

This commit is contained in:
leo
2024-03-07 20:48:30 +08:00
parent 71068f00dd
commit d785084cfd
2 changed files with 40 additions and 4 deletions

View File

@@ -25,3 +25,16 @@ class BatchDataRead(BatchDataBase):
orm_mode = True orm_mode = True
class BatchDataCreate(BatchDataBase):
year: int | None = 2024
census_batch: str | None = '' # 普查批次
id_code: str | None = '' # 编号
precision: str | None = '' # 精度
is_train: bool | None = True
is_validation: bool | None = False
ann_file: str | None = ''
ann_file_lbs: str | None = ''
img_prefix: str | None = ''
filter_empty_gt: bool | None = False
update_cache: bool | None = False

31
main.py
View File

@@ -5,7 +5,7 @@ import pandas as pd
import streamlit as st import streamlit as st
from loguru import logger from loguru import logger
from bak.init_data import BatchDataRead from bak.init_data import BatchDataCreate, BatchDataRead
from db_utils import BatchData, session from db_utils import BatchData, session
BASE_CKPT_DIR = "./bak" BASE_CKPT_DIR = "./bak"
@@ -56,22 +56,44 @@ def train():
st.session_state.configs['evolve_r'] = st.session_state.evolve_r st.session_state.configs['evolve_r'] = st.session_state.evolve_r
def update_handler(): def table_update_handler():
logger.debug(f"{st.session_state.edited_info=}")
edited_rows = st.session_state.edited_info.get('edited_rows') edited_rows = st.session_state.edited_info.get('edited_rows')
added_rows = st.session_state.edited_info.get('added_rows')
deleted_rows = st.session_state.edited_info.get('deleted_rows')
# 更改列
for id_, update_data in edited_rows.items(): for id_, update_data in edited_rows.items():
row_id = int(edited_df.loc[id_].id) row_id = int(edited_df.loc[id_].id)
row_db = session.query(BatchData).where(BatchData.id == row_id).first() row_db = session.query(BatchData).where(BatchData.id == row_id).first()
logger.info(f"{row_id=}, {update_data=}") logger.info(f"{row_id=}, {update_data=}")
logger.debug(BatchDataRead.from_orm(row_db))
for field in update_data: for field in update_data:
setattr(row_db, field, update_data[field]) setattr(row_db, field, update_data[field])
session.commit() session.commit()
# 增加列
for new_row_data in added_rows:
logger.debug(f"{new_row_data=}")
for deleted_row in deleted_rows:
row_id = int(edited_df.loc[deleted_row].id)
logger.debug(f"{row_id=}")
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
session.delete(row_db)
session.commit()
def update_config(*args, **kwargs): def update_config(*args, **kwargs):
return None return None
def create_new_row():
session.add(
BatchData(**BatchDataCreate().dict())
)
session.commit()
st.session_state.data_table = get_data_from_db()
with left_col: with left_col:
edited_df = st.data_editor( edited_df = st.data_editor(
df, key="edited_info", df, key="edited_info",
@@ -79,7 +101,7 @@ with left_col:
height=600, height=600,
hide_index=True, hide_index=True,
use_container_width=True, use_container_width=True,
on_change=update_handler, on_change=table_update_handler,
column_order=( column_order=(
'id', 'year', 'census_batch', 'id_code', 'precision', 'is_train', 'is_validation', 'ann_file', 'ann_file_lbs', 'img_prefix',), 'id', 'year', 'census_batch', 'id_code', 'precision', 'is_train', 'is_validation', 'ann_file', 'ann_file_lbs', 'img_prefix',),
column_config={ column_config={
@@ -93,6 +115,7 @@ with left_col:
'ann_file_lbs': st.column_config.Column("lbs_path", width='medium'), 'ann_file_lbs': st.column_config.Column("lbs_path", width='medium'),
'img_prefix': st.column_config.Column("img_prefix", width='medium'), 'img_prefix': st.column_config.Column("img_prefix", width='medium'),
}) })
st.button("新建一行", on_click=create_new_row)
with right_col: with right_col:
st.slider(label='evolve_r', key='evolve_r', st.slider(label='evolve_r', key='evolve_r',