优化1
This commit is contained in:
55
main.py
55
main.py
@@ -43,42 +43,31 @@ def make_train_dicts(with_entities: tuple, conditions: tuple):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_train_data(mode_id) -> dict[str, list]:
|
def prepare_train_data(mode_id) -> dict[str, list]:
|
||||||
res = {
|
train_we, train_cd = tuple(), tuple()
|
||||||
"train_data": [],
|
val_we, val_cd = tuple(), tuple()
|
||||||
"val_data": [],
|
|
||||||
}
|
|
||||||
if mode_id == 1:
|
if mode_id == 1:
|
||||||
res['train_data'] = make_train_dicts(
|
train_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
|
||||||
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
|
train_cd = (BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
|
||||||
conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
|
val_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
|
||||||
)
|
val_cd = (BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
|
||||||
res['val_data'] = make_train_dicts(
|
|
||||||
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
|
|
||||||
conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
|
|
||||||
)
|
|
||||||
elif mode_id == 2:
|
elif mode_id == 2:
|
||||||
res['train_data'] = make_train_dicts(
|
train_we = (BatchData.img_prefix,
|
||||||
with_entities=(BatchData.img_prefix,
|
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
else_=BatchData.ann_file).label('ann_file'))
|
||||||
else_=BatchData.ann_file).label('ann_file')),
|
train_cd = (BatchData.is_train == 1,)
|
||||||
conditions=(BatchData.is_train == 1,)
|
val_we = (BatchData.img_prefix,
|
||||||
)
|
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||||
res['val_data'] = make_train_dicts(
|
else_=BatchData.ann_file).label('ann_file'))
|
||||||
with_entities=(BatchData.img_prefix,
|
val_cd = (BatchData.is_validation == 1,)
|
||||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
|
||||||
else_=BatchData.ann_file).label('ann_file')),
|
|
||||||
conditions=(BatchData.is_validation == 1,)
|
|
||||||
)
|
|
||||||
elif mode_id == 3:
|
elif mode_id == 3:
|
||||||
res['train_data'] = make_train_dicts(
|
train_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
|
||||||
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')),
|
train_cd = (BatchData.is_train == 1, BatchData.ann_file.is_not(None))
|
||||||
conditions=(BatchData.is_train == 1, BatchData.ann_file.is_not(None))
|
val_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
|
||||||
)
|
val_cd = (BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
|
||||||
res['val_data'] = make_train_dicts(
|
return {
|
||||||
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')),
|
"train_data": make_train_dicts(train_we, train_cd),
|
||||||
conditions=(BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
|
"val_data": make_train_dicts(val_we, val_cd),
|
||||||
)
|
}
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
|
|||||||
Reference in New Issue
Block a user