This commit is contained in:
leo
2024-03-07 23:11:04 +08:00
parent 09b90a09d7
commit f22d134fb4

51
main.py
View File

@@ -43,42 +43,31 @@ def make_train_dicts(with_entities: tuple, conditions: tuple):
def prepare_train_data(mode_id) -> dict[str, list]: def prepare_train_data(mode_id) -> dict[str, list]:
res = { train_we, train_cd = tuple(), tuple()
"train_data": [], val_we, val_cd = tuple(), tuple()
"val_data": [],
}
if mode_id == 1: if mode_id == 1:
res['train_data'] = make_train_dicts( train_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')), train_cd = (BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None)) val_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
) val_cd = (BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
)
elif mode_id == 2: elif mode_id == 2:
res['train_data'] = make_train_dicts( train_we = (BatchData.img_prefix,
with_entities=(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file')), else_=BatchData.ann_file).label('ann_file'))
conditions=(BatchData.is_train == 1,) train_cd = (BatchData.is_train == 1,)
) val_we = (BatchData.img_prefix,
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file')), else_=BatchData.ann_file).label('ann_file'))
conditions=(BatchData.is_validation == 1,) val_cd = (BatchData.is_validation == 1,)
)
elif mode_id == 3: elif mode_id == 3:
res['train_data'] = make_train_dicts( train_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')), train_cd = (BatchData.is_train == 1, BatchData.ann_file.is_not(None))
conditions=(BatchData.is_train == 1, BatchData.ann_file.is_not(None)) val_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
) val_cd = (BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
res['val_data'] = make_train_dicts( return {
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')), "train_data": make_train_dicts(train_we, train_cd),
conditions=(BatchData.is_validation == 1, BatchData.ann_file.is_not(None)) "val_data": make_train_dicts(val_we, val_cd),
) }
return res
@st.cache_data @st.cache_data