From da511782117d2624d22500220308b6adf6df0c30 Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 7 Mar 2024 23:04:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96mode=201,2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 58 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/main.py b/main.py index 4a2e9b4..6d8a11d 100644 --- a/main.py +++ b/main.py @@ -32,50 +32,52 @@ class DataRow(BaseModel): update_cache: bool = False -def make_train_dicts(conditions: tuple): +def make_train_dicts(with_entities: tuple, conditions: tuple): return [ - DataRow(path=_.ann_file_lbs, + DataRow(path=_.ann_file, img_prefix=_.img_prefix).dict() - for _ in session.query(BatchData).where(and_(*conditions)).all() + for _ in session.query(BatchData) + .with_entities(*with_entities) + .where(and_(*conditions)).all() ] -def make_train_data(mode_id) -> dict[str, list]: +def prepare_train_data(mode_id) -> dict[str, list]: res = { "train_data": [], "val_data": [], } if mode_id == 1: + res['train_data'] = make_train_dicts( + with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')), + conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None)) + ) + res['val_data'] = make_train_dicts( + with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')), + conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None)) + ) + elif mode_id == 2: + res['train_data'] = make_train_dicts( + with_entities=(BatchData.img_prefix, + case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), + else_=BatchData.ann_file).label('ann_file')), + conditions=(BatchData.is_train == 1,) + ) + res['val_data'] = make_train_dicts( + with_entities=(BatchData.img_prefix, + case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), + else_=BatchData.ann_file).label('ann_file')), + conditions=(BatchData.is_validation == 1,) + ) + elif mode_id == 3: res['train_data'] = make_train_dicts(conditions=( BatchData.is_train == 1, - BatchData.ann_file_lbs.is_not(None) + BatchData.ann_file.is_not(None) )) val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() - for db_obj in val_db_objs: - res['val_data'].append(DataRow(path=db_obj.ann_file_lbs, - img_prefix=db_obj.img_prefix).dict()) - - elif mode_id == 2: - results = (session.query(BatchData) - .with_entities(BatchData.img_prefix, - case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), - else_=BatchData.ann_file).label('ann_file')) - .where(BatchData.is_train == 1).all()) - - res['train_data'] = [ - DataRow(path=_.ann_file, - img_prefix=_.img_prefix).dict() - for _ in results - ] - val_db_objs = (session.query(BatchData) - .with_entities(BatchData.img_prefix, - case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), - else_=BatchData.ann_file).label('ann_file')) - .where(BatchData.is_validation == 1).all()) for db_obj in val_db_objs: res['val_data'].append(DataRow(path=db_obj.ann_file, img_prefix=db_obj.img_prefix).dict()) - return res @@ -108,7 +110,7 @@ left_col, right_col = st.columns([3, 1]) def train(): right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')}) - right_col.json(make_train_data(st.session_state.mode)) + right_col.json(prepare_train_data(st.session_state.mode)) def table_update_handler():