支持增加 删除行
This commit is contained in:
@@ -25,3 +25,16 @@ class BatchDataRead(BatchDataBase):
|
|||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDataCreate(BatchDataBase):
|
||||||
|
year: int | None = 2024
|
||||||
|
census_batch: str | None = '' # 普查批次
|
||||||
|
id_code: str | None = '' # 编号
|
||||||
|
precision: str | None = '' # 精度
|
||||||
|
is_train: bool | None = True
|
||||||
|
is_validation: bool | None = False
|
||||||
|
|
||||||
|
ann_file: str | None = ''
|
||||||
|
ann_file_lbs: str | None = ''
|
||||||
|
img_prefix: str | None = ''
|
||||||
|
filter_empty_gt: bool | None = False
|
||||||
|
update_cache: bool | None = False
|
||||||
|
|||||||
31
main.py
31
main.py
@@ -5,7 +5,7 @@ import pandas as pd
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from bak.init_data import BatchDataRead
|
from bak.init_data import BatchDataCreate, BatchDataRead
|
||||||
from db_utils import BatchData, session
|
from db_utils import BatchData, session
|
||||||
|
|
||||||
BASE_CKPT_DIR = "./bak"
|
BASE_CKPT_DIR = "./bak"
|
||||||
@@ -56,22 +56,44 @@ def train():
|
|||||||
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
st.session_state.configs['evolve_r'] = st.session_state.evolve_r
|
||||||
|
|
||||||
|
|
||||||
def update_handler():
|
def table_update_handler():
|
||||||
|
logger.debug(f"{st.session_state.edited_info=}")
|
||||||
edited_rows = st.session_state.edited_info.get('edited_rows')
|
edited_rows = st.session_state.edited_info.get('edited_rows')
|
||||||
|
added_rows = st.session_state.edited_info.get('added_rows')
|
||||||
|
deleted_rows = st.session_state.edited_info.get('deleted_rows')
|
||||||
|
# 更改列
|
||||||
for id_, update_data in edited_rows.items():
|
for id_, update_data in edited_rows.items():
|
||||||
row_id = int(edited_df.loc[id_].id)
|
row_id = int(edited_df.loc[id_].id)
|
||||||
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
|
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
|
||||||
logger.info(f"{row_id=}, {update_data=}")
|
logger.info(f"{row_id=}, {update_data=}")
|
||||||
logger.debug(BatchDataRead.from_orm(row_db))
|
|
||||||
for field in update_data:
|
for field in update_data:
|
||||||
setattr(row_db, field, update_data[field])
|
setattr(row_db, field, update_data[field])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
# 增加列
|
||||||
|
for new_row_data in added_rows:
|
||||||
|
logger.debug(f"{new_row_data=}")
|
||||||
|
|
||||||
|
for deleted_row in deleted_rows:
|
||||||
|
row_id = int(edited_df.loc[deleted_row].id)
|
||||||
|
logger.debug(f"{row_id=}")
|
||||||
|
row_db = session.query(BatchData).where(BatchData.id == row_id).first()
|
||||||
|
session.delete(row_db)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
def update_config(*args, **kwargs):
|
def update_config(*args, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_new_row():
|
||||||
|
session.add(
|
||||||
|
BatchData(**BatchDataCreate().dict())
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
st.session_state.data_table = get_data_from_db()
|
||||||
|
|
||||||
|
|
||||||
with left_col:
|
with left_col:
|
||||||
edited_df = st.data_editor(
|
edited_df = st.data_editor(
|
||||||
df, key="edited_info",
|
df, key="edited_info",
|
||||||
@@ -79,7 +101,7 @@ with left_col:
|
|||||||
height=600,
|
height=600,
|
||||||
hide_index=True,
|
hide_index=True,
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
on_change=update_handler,
|
on_change=table_update_handler,
|
||||||
column_order=(
|
column_order=(
|
||||||
'id', 'year', 'census_batch', 'id_code', 'precision', 'is_train', 'is_validation', 'ann_file', 'ann_file_lbs', 'img_prefix',),
|
'id', 'year', 'census_batch', 'id_code', 'precision', 'is_train', 'is_validation', 'ann_file', 'ann_file_lbs', 'img_prefix',),
|
||||||
column_config={
|
column_config={
|
||||||
@@ -93,6 +115,7 @@ with left_col:
|
|||||||
'ann_file_lbs': st.column_config.Column("lbs_path", width='medium'),
|
'ann_file_lbs': st.column_config.Column("lbs_path", width='medium'),
|
||||||
'img_prefix': st.column_config.Column("img_prefix", width='medium'),
|
'img_prefix': st.column_config.Column("img_prefix", width='medium'),
|
||||||
})
|
})
|
||||||
|
st.button("新建一行", on_click=create_new_row)
|
||||||
|
|
||||||
with right_col:
|
with right_col:
|
||||||
st.slider(label='evolve_r', key='evolve_r',
|
st.slider(label='evolve_r', key='evolve_r',
|
||||||
|
|||||||
Reference in New Issue
Block a user