From bf2833e590a177ab5e29236c3e0544ccf20d487a Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 7 Mar 2024 22:32:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0mode=201?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 3 ++- main.py | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index cf49553..6585495 100644 --- a/README.md +++ b/README.md @@ -26,4 +26,5 @@ lbs_path 不生成 lbs模式: select(train==1, path=lbs_path) lbs优先模式: select(train==1, path=lbs_path if lbs_path else path) -不使用lbs模式: select(train==1, path=path) \ No newline at end of file +不使用lbs模式: select(train==1, path=path) + diff --git a/main.py b/main.py index 272f1b0..b08c4a4 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,8 @@ import os import pandas as pd import streamlit as st from loguru import logger +from pydantic import BaseModel +from sqlalchemy import and_ from bak.init_data import BatchDataCreate, BatchDataRead from db_utils import BatchData, session @@ -23,6 +25,35 @@ st.set_page_config( ) +class DataRow(BaseModel): + path: str | None = None + img_prefix: str | None = None + filter_empty_gt: bool = False + update_cache: bool = False + + +def make_train_data(mode_id) -> dict[str, list]: + res = { + "train_data": [], + "val_data": [], + } + if mode_id == 1: + for db_obj in session.query(BatchData).where(and_( + BatchData.is_train == 1, + BatchData.ann_file_lbs.is_not(None) + )).all(): + res['train_data'].append(DataRow(path=db_obj.ann_file_lbs, + img_prefix=db_obj.img_prefix).dict()) + + + + val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() + for db_obj in val_db_objs: + res['val_data'].append(DataRow(path=db_obj.ann_file_lbs, + img_prefix=db_obj.img_prefix).dict()) + return res + + @st.cache_data def list_ckpt_paths(dir_path): return os.listdir(dir_path) @@ -36,11 +67,11 @@ def get_data_from_db(): st.session_state.setdefault('data_table', []) st.session_state.setdefault('username', '') -st.session_state.setdefault('evolve_r', 0.2) +st.session_state.setdefault('evolve_r', 0.02) st.session_state.setdefault('n_trail', 10) -st.session_state.setdefault('n_epoch', 2) +st.session_state.setdefault('n_epoch', 1) st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0]) -st.session_state.setdefault('mode', 2) +st.session_state.setdefault('mode', 1) st.session_state.setdefault('configs', { k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') }) @@ -52,6 +83,7 @@ left_col, right_col = st.columns([3, 1]) def train(): right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')}) + right_col.json(make_train_data(st.session_state.mode)) def table_update_handler():