From 09b90a09d7fc25c1861ee773bce9e32776c8374c Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 7 Mar 2024 23:06:03 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0mode=203?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 6d8a11d..d0e4969 100644 --- a/main.py +++ b/main.py @@ -70,14 +70,14 @@ def prepare_train_data(mode_id) -> dict[str, list]: conditions=(BatchData.is_validation == 1,) ) elif mode_id == 3: - res['train_data'] = make_train_dicts(conditions=( - BatchData.is_train == 1, - BatchData.ann_file.is_not(None) - )) - val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() - for db_obj in val_db_objs: - res['val_data'].append(DataRow(path=db_obj.ann_file, - img_prefix=db_obj.img_prefix).dict()) + res['train_data'] = make_train_dicts( + with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')), + conditions=(BatchData.is_train == 1, BatchData.ann_file.is_not(None)) + ) + res['val_data'] = make_train_dicts( + with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')), + conditions=(BatchData.is_validation == 1, BatchData.ann_file.is_not(None)) + ) return res