实现mode 1

This commit is contained in:
leo
2024-03-07 22:32:28 +08:00
parent 073904de36
commit bf2833e590
2 changed files with 37 additions and 4 deletions

View File

@@ -27,3 +27,4 @@ lbs_path 不生成
lbs模式: select(train==1, path=lbs_path) lbs模式: select(train==1, path=lbs_path)
lbs优先模式: select(train==1, path=lbs_path if lbs_path else path) lbs优先模式: select(train==1, path=lbs_path if lbs_path else path)
不使用lbs模式: select(train==1, path=path) 不使用lbs模式: select(train==1, path=path)

38
main.py
View File

@@ -4,6 +4,8 @@ import os
import pandas as pd import pandas as pd
import streamlit as st import streamlit as st
from loguru import logger from loguru import logger
from pydantic import BaseModel
from sqlalchemy import and_
from bak.init_data import BatchDataCreate, BatchDataRead from bak.init_data import BatchDataCreate, BatchDataRead
from db_utils import BatchData, session from db_utils import BatchData, session
@@ -23,6 +25,35 @@ st.set_page_config(
) )
class DataRow(BaseModel):
path: str | None = None
img_prefix: str | None = None
filter_empty_gt: bool = False
update_cache: bool = False
def make_train_data(mode_id) -> dict[str, list]:
res = {
"train_data": [],
"val_data": [],
}
if mode_id == 1:
for db_obj in session.query(BatchData).where(and_(
BatchData.is_train == 1,
BatchData.ann_file_lbs.is_not(None)
)).all():
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
return res
@st.cache_data @st.cache_data
def list_ckpt_paths(dir_path): def list_ckpt_paths(dir_path):
return os.listdir(dir_path) return os.listdir(dir_path)
@@ -36,11 +67,11 @@ def get_data_from_db():
st.session_state.setdefault('data_table', []) st.session_state.setdefault('data_table', [])
st.session_state.setdefault('username', '') st.session_state.setdefault('username', '')
st.session_state.setdefault('evolve_r', 0.2) st.session_state.setdefault('evolve_r', 0.02)
st.session_state.setdefault('n_trail', 10) st.session_state.setdefault('n_trail', 10)
st.session_state.setdefault('n_epoch', 2) st.session_state.setdefault('n_epoch', 1)
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0]) st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
st.session_state.setdefault('mode', 2) st.session_state.setdefault('mode', 1)
st.session_state.setdefault('configs', { st.session_state.setdefault('configs', {
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
}) })
@@ -52,6 +83,7 @@ left_col, right_col = st.columns([3, 1])
def train(): def train():
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')}) right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
right_col.json(make_train_data(st.session_state.mode))
def table_update_handler(): def table_update_handler():