优化mode 1,2
This commit is contained in:
58
main.py
58
main.py
@@ -32,50 +32,52 @@ class DataRow(BaseModel):
|
|||||||
update_cache: bool = False
|
update_cache: bool = False
|
||||||
|
|
||||||
|
|
||||||
def make_train_dicts(conditions: tuple):
|
def make_train_dicts(with_entities: tuple, conditions: tuple):
|
||||||
return [
|
return [
|
||||||
DataRow(path=_.ann_file_lbs,
|
DataRow(path=_.ann_file,
|
||||||
img_prefix=_.img_prefix).dict()
|
img_prefix=_.img_prefix).dict()
|
||||||
for _ in session.query(BatchData).where(and_(*conditions)).all()
|
for _ in session.query(BatchData)
|
||||||
|
.with_entities(*with_entities)
|
||||||
|
.where(and_(*conditions)).all()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_train_data(mode_id) -> dict[str, list]:
|
def prepare_train_data(mode_id) -> dict[str, list]:
|
||||||
res = {
|
res = {
|
||||||
"train_data": [],
|
"train_data": [],
|
||||||
"val_data": [],
|
"val_data": [],
|
||||||
}
|
}
|
||||||
if mode_id == 1:
|
if mode_id == 1:
|
||||||
|
res['train_data'] = make_train_dicts(
|
||||||
|
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
|
||||||
|
conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
|
||||||
|
)
|
||||||
|
res['val_data'] = make_train_dicts(
|
||||||
|
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
|
||||||
|
conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
|
||||||
|
)
|
||||||
|
elif mode_id == 2:
|
||||||
|
res['train_data'] = make_train_dicts(
|
||||||
|
with_entities=(BatchData.img_prefix,
|
||||||
|
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||||
|
else_=BatchData.ann_file).label('ann_file')),
|
||||||
|
conditions=(BatchData.is_train == 1,)
|
||||||
|
)
|
||||||
|
res['val_data'] = make_train_dicts(
|
||||||
|
with_entities=(BatchData.img_prefix,
|
||||||
|
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||||
|
else_=BatchData.ann_file).label('ann_file')),
|
||||||
|
conditions=(BatchData.is_validation == 1,)
|
||||||
|
)
|
||||||
|
elif mode_id == 3:
|
||||||
res['train_data'] = make_train_dicts(conditions=(
|
res['train_data'] = make_train_dicts(conditions=(
|
||||||
BatchData.is_train == 1,
|
BatchData.is_train == 1,
|
||||||
BatchData.ann_file_lbs.is_not(None)
|
BatchData.ann_file.is_not(None)
|
||||||
))
|
))
|
||||||
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
||||||
for db_obj in val_db_objs:
|
|
||||||
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
|
||||||
img_prefix=db_obj.img_prefix).dict())
|
|
||||||
|
|
||||||
elif mode_id == 2:
|
|
||||||
results = (session.query(BatchData)
|
|
||||||
.with_entities(BatchData.img_prefix,
|
|
||||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
|
||||||
else_=BatchData.ann_file).label('ann_file'))
|
|
||||||
.where(BatchData.is_train == 1).all())
|
|
||||||
|
|
||||||
res['train_data'] = [
|
|
||||||
DataRow(path=_.ann_file,
|
|
||||||
img_prefix=_.img_prefix).dict()
|
|
||||||
for _ in results
|
|
||||||
]
|
|
||||||
val_db_objs = (session.query(BatchData)
|
|
||||||
.with_entities(BatchData.img_prefix,
|
|
||||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
|
||||||
else_=BatchData.ann_file).label('ann_file'))
|
|
||||||
.where(BatchData.is_validation == 1).all())
|
|
||||||
for db_obj in val_db_objs:
|
for db_obj in val_db_objs:
|
||||||
res['val_data'].append(DataRow(path=db_obj.ann_file,
|
res['val_data'].append(DataRow(path=db_obj.ann_file,
|
||||||
img_prefix=db_obj.img_prefix).dict())
|
img_prefix=db_obj.img_prefix).dict())
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@@ -108,7 +110,7 @@ left_col, right_col = st.columns([3, 1])
|
|||||||
|
|
||||||
def train():
|
def train():
|
||||||
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
|
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
|
||||||
right_col.json(make_train_data(st.session_state.mode))
|
right_col.json(prepare_train_data(st.session_state.mode))
|
||||||
|
|
||||||
|
|
||||||
def table_update_handler():
|
def table_update_handler():
|
||||||
|
|||||||
Reference in New Issue
Block a user