From f22d134fb45bb8fe71747f6ebd489249b1036412 Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 7 Mar 2024 23:11:04 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=961?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 55 ++++++++++++++++++++++--------------------------------- 1 file changed, 22 insertions(+), 33 deletions(-) diff --git a/main.py b/main.py index d0e4969..f9a9876 100644 --- a/main.py +++ b/main.py @@ -43,42 +43,31 @@ def make_train_dicts(with_entities: tuple, conditions: tuple): def prepare_train_data(mode_id) -> dict[str, list]: - res = { - "train_data": [], - "val_data": [], - } + train_we, train_cd = tuple(), tuple() + val_we, val_cd = tuple(), tuple() if mode_id == 1: - res['train_data'] = make_train_dicts( - with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')), - conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None)) - ) - res['val_data'] = make_train_dicts( - with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')), - conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None)) - ) + train_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')) + train_cd = (BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None)) + val_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')) + val_cd = (BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None)) elif mode_id == 2: - res['train_data'] = make_train_dicts( - with_entities=(BatchData.img_prefix, - case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), - else_=BatchData.ann_file).label('ann_file')), - conditions=(BatchData.is_train == 1,) - ) - res['val_data'] = make_train_dicts( - with_entities=(BatchData.img_prefix, - case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), - else_=BatchData.ann_file).label('ann_file')), - conditions=(BatchData.is_validation == 1,) - ) + train_we = (BatchData.img_prefix, + case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), + else_=BatchData.ann_file).label('ann_file')) + train_cd = (BatchData.is_train == 1,) + val_we = (BatchData.img_prefix, + case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), + else_=BatchData.ann_file).label('ann_file')) + val_cd = (BatchData.is_validation == 1,) elif mode_id == 3: - res['train_data'] = make_train_dicts( - with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')), - conditions=(BatchData.is_train == 1, BatchData.ann_file.is_not(None)) - ) - res['val_data'] = make_train_dicts( - with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')), - conditions=(BatchData.is_validation == 1, BatchData.ann_file.is_not(None)) - ) - return res + train_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file')) + train_cd = (BatchData.is_train == 1, BatchData.ann_file.is_not(None)) + val_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file')) + val_cd = (BatchData.is_validation == 1, BatchData.ann_file.is_not(None)) + return { + "train_data": make_train_dicts(train_we, train_cd), + "val_data": make_train_dicts(val_we, val_cd), + } @st.cache_data