diff --git a/bak/init_data.py b/bak/init_data.py index 9590782..1d4ea10 100644 --- a/bak/init_data.py +++ b/bak/init_data.py @@ -25,3 +25,16 @@ class BatchDataRead(BatchDataBase): orm_mode = True +class BatchDataCreate(BatchDataBase): + year: int | None = 2024 + census_batch: str | None = '' # 普查批次 + id_code: str | None = '' # 编号 + precision: str | None = '' # 精度 + is_train: bool | None = True + is_validation: bool | None = False + + ann_file: str | None = '' + ann_file_lbs: str | None = '' + img_prefix: str | None = '' + filter_empty_gt: bool | None = False + update_cache: bool | None = False diff --git a/main.py b/main.py index 62b25dc..2bd87f5 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import pandas as pd import streamlit as st from loguru import logger -from bak.init_data import BatchDataRead +from bak.init_data import BatchDataCreate, BatchDataRead from db_utils import BatchData, session BASE_CKPT_DIR = "./bak" @@ -56,22 +56,44 @@ def train(): st.session_state.configs['evolve_r'] = st.session_state.evolve_r -def update_handler(): +def table_update_handler(): + logger.debug(f"{st.session_state.edited_info=}") edited_rows = st.session_state.edited_info.get('edited_rows') + added_rows = st.session_state.edited_info.get('added_rows') + deleted_rows = st.session_state.edited_info.get('deleted_rows') + # 更改列 for id_, update_data in edited_rows.items(): row_id = int(edited_df.loc[id_].id) row_db = session.query(BatchData).where(BatchData.id == row_id).first() logger.info(f"{row_id=}, {update_data=}") - logger.debug(BatchDataRead.from_orm(row_db)) for field in update_data: setattr(row_db, field, update_data[field]) session.commit() + # 增加列 + for new_row_data in added_rows: + logger.debug(f"{new_row_data=}") + + for deleted_row in deleted_rows: + row_id = int(edited_df.loc[deleted_row].id) + logger.debug(f"{row_id=}") + row_db = session.query(BatchData).where(BatchData.id == row_id).first() + session.delete(row_db) + session.commit() + def update_config(*args, **kwargs): return None +def create_new_row(): + session.add( + BatchData(**BatchDataCreate().dict()) + ) + session.commit() + st.session_state.data_table = get_data_from_db() + + with left_col: edited_df = st.data_editor( df, key="edited_info", @@ -79,7 +101,7 @@ with left_col: height=600, hide_index=True, use_container_width=True, - on_change=update_handler, + on_change=table_update_handler, column_order=( 'id', 'year', 'census_batch', 'id_code', 'precision', 'is_train', 'is_validation', 'ann_file', 'ann_file_lbs', 'img_prefix',), column_config={ @@ -93,6 +115,7 @@ with left_col: 'ann_file_lbs': st.column_config.Column("lbs_path", width='medium'), 'img_prefix': st.column_config.Column("img_prefix", width='medium'), }) + st.button("新建一行", on_click=create_new_row) with right_col: st.slider(label='evolve_r', key='evolve_r',