实现mode 1
This commit is contained in:
38
main.py
38
main.py
@@ -4,6 +4,8 @@ import os
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_
|
||||
|
||||
from bak.init_data import BatchDataCreate, BatchDataRead
|
||||
from db_utils import BatchData, session
|
||||
@@ -23,6 +25,35 @@ st.set_page_config(
|
||||
)
|
||||
|
||||
|
||||
class DataRow(BaseModel):
|
||||
path: str | None = None
|
||||
img_prefix: str | None = None
|
||||
filter_empty_gt: bool = False
|
||||
update_cache: bool = False
|
||||
|
||||
|
||||
def make_train_data(mode_id) -> dict[str, list]:
|
||||
res = {
|
||||
"train_data": [],
|
||||
"val_data": [],
|
||||
}
|
||||
if mode_id == 1:
|
||||
for db_obj in session.query(BatchData).where(and_(
|
||||
BatchData.is_train == 1,
|
||||
BatchData.ann_file_lbs.is_not(None)
|
||||
)).all():
|
||||
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||
img_prefix=db_obj.img_prefix).dict())
|
||||
|
||||
|
||||
|
||||
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
||||
for db_obj in val_db_objs:
|
||||
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||
img_prefix=db_obj.img_prefix).dict())
|
||||
return res
|
||||
|
||||
|
||||
@st.cache_data
|
||||
def list_ckpt_paths(dir_path):
|
||||
return os.listdir(dir_path)
|
||||
@@ -36,11 +67,11 @@ def get_data_from_db():
|
||||
|
||||
st.session_state.setdefault('data_table', [])
|
||||
st.session_state.setdefault('username', '')
|
||||
st.session_state.setdefault('evolve_r', 0.2)
|
||||
st.session_state.setdefault('evolve_r', 0.02)
|
||||
st.session_state.setdefault('n_trail', 10)
|
||||
st.session_state.setdefault('n_epoch', 2)
|
||||
st.session_state.setdefault('n_epoch', 1)
|
||||
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
||||
st.session_state.setdefault('mode', 2)
|
||||
st.session_state.setdefault('mode', 1)
|
||||
st.session_state.setdefault('configs', {
|
||||
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
||||
})
|
||||
@@ -52,6 +83,7 @@ left_col, right_col = st.columns([3, 1])
|
||||
|
||||
def train():
|
||||
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
|
||||
right_col.json(make_train_data(st.session_state.mode))
|
||||
|
||||
|
||||
def table_update_handler():
|
||||
|
||||
Reference in New Issue
Block a user