优化1
This commit is contained in:
51
main.py
51
main.py
@@ -43,42 +43,31 @@ def make_train_dicts(with_entities: tuple, conditions: tuple):
|
||||
|
||||
|
||||
def prepare_train_data(mode_id) -> dict[str, list]:
|
||||
res = {
|
||||
"train_data": [],
|
||||
"val_data": [],
|
||||
}
|
||||
train_we, train_cd = tuple(), tuple()
|
||||
val_we, val_cd = tuple(), tuple()
|
||||
if mode_id == 1:
|
||||
res['train_data'] = make_train_dicts(
|
||||
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
|
||||
conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
|
||||
)
|
||||
res['val_data'] = make_train_dicts(
|
||||
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
|
||||
conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
|
||||
)
|
||||
train_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
|
||||
train_cd = (BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
|
||||
val_we = (BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file'))
|
||||
val_cd = (BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
|
||||
elif mode_id == 2:
|
||||
res['train_data'] = make_train_dicts(
|
||||
with_entities=(BatchData.img_prefix,
|
||||
train_we = (BatchData.img_prefix,
|
||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||
else_=BatchData.ann_file).label('ann_file')),
|
||||
conditions=(BatchData.is_train == 1,)
|
||||
)
|
||||
res['val_data'] = make_train_dicts(
|
||||
with_entities=(BatchData.img_prefix,
|
||||
else_=BatchData.ann_file).label('ann_file'))
|
||||
train_cd = (BatchData.is_train == 1,)
|
||||
val_we = (BatchData.img_prefix,
|
||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||
else_=BatchData.ann_file).label('ann_file')),
|
||||
conditions=(BatchData.is_validation == 1,)
|
||||
)
|
||||
else_=BatchData.ann_file).label('ann_file'))
|
||||
val_cd = (BatchData.is_validation == 1,)
|
||||
elif mode_id == 3:
|
||||
res['train_data'] = make_train_dicts(
|
||||
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')),
|
||||
conditions=(BatchData.is_train == 1, BatchData.ann_file.is_not(None))
|
||||
)
|
||||
res['val_data'] = make_train_dicts(
|
||||
with_entities=(BatchData.img_prefix, BatchData.ann_file.label('ann_file')),
|
||||
conditions=(BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
|
||||
)
|
||||
return res
|
||||
train_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
|
||||
train_cd = (BatchData.is_train == 1, BatchData.ann_file.is_not(None))
|
||||
val_we = (BatchData.img_prefix, BatchData.ann_file.label('ann_file'))
|
||||
val_cd = (BatchData.is_validation == 1, BatchData.ann_file.is_not(None))
|
||||
return {
|
||||
"train_data": make_train_dicts(train_we, train_cd),
|
||||
"val_data": make_train_dicts(val_we, val_cd),
|
||||
}
|
||||
|
||||
|
||||
@st.cache_data
|
||||
|
||||
Reference in New Issue
Block a user