This commit is contained in:
leo
2024-03-07 23:11:04 +08:00
parent 09b90a09d7
commit f22d134fb4

55
main.py
View File

@@ -43,42 +43,31 @@ def make_train_dicts(with_entities: tuple, conditions: tuple):
def prepare_train_data(mode_id) -> dict[str, list]:
res = {
"train_data": [],
"val_data": [],
}
train_we, train_cd = tuple(), tuple()
val_we, val_cd = tuple(), tuple()
if mode_id == 1:
res['train_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
)
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
)
train_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
train_cd = (BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
val_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
val_cd = (BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
elif mode_id == 2:
res['train_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file')),
conditions=(BatchData.is_train == 1,)
)
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file')),
conditions=(BatchData.is_validation == 1,)
)
train_we = (BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
train_cd = (BatchData.is_train == 1,)
val_we = (BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
val_cd = (BatchData.is_validation == 1,)
elif mode_id == 3:
res['train_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')),
conditions=(BatchData.is_train == 1, BatchData.ann_file.is_not(None))
)
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')),
conditions=(BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
)
return res
train_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
train_cd = (BatchData.is_train == 1, BatchData.ann_file.is_not(None))
val_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
val_cd = (BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
return {
"train_data": make_train_dicts(train_we, train_cd),
"val_data": make_train_dicts(val_we, val_cd),
}
@st.cache_data