Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion openless-all/app/src-tauri/src/asr/local/sherpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//!
//! 推理接入见 `sherpa_runtime.rs`(M2)。

use std::path::PathBuf;
use std::path::{Path, PathBuf};

use anyhow::Result;
use serde::Serialize;
Expand Down Expand Up @@ -128,6 +128,18 @@ pub fn required_files_for_alias(alias: &str) -> Result<&'static [&'static str]>
}
}

pub fn required_path_is_valid(alias: &str, required: &str, path: &Path) -> bool {
if required_path_is_dir(alias, required) {
path.is_dir()
} else {
path.is_file()
}
}

fn required_path_is_dir(alias: &str, required: &str) -> bool {
matches!((alias, required), ("qwen3-asr-0.6b-int8", "tokenizer"))
}

pub fn download_files_for_alias(alias: &str) -> Result<&'static [(&'static str, &'static str)]> {
match alias {
"sense-voice-small-zh" => Ok(&[
Expand Down Expand Up @@ -184,6 +196,7 @@ pub struct SherpaCatalogModel {
pub mode: SherpaMode,
pub languages: Vec<String>,
pub cached: bool,
pub downloaded_bytes: u64,
pub file_size_mb: Option<u64>,
}

Expand All @@ -197,6 +210,7 @@ impl SherpaCatalogModel {
mode: model.mode,
languages: model.languages.iter().map(|s| s.to_string()).collect(),
cached: false,
downloaded_bytes: 0,
file_size_mb: None,
}
}
Expand Down
138 changes: 114 additions & 24 deletions openless-all/app/src-tauri/src/asr/local/sherpa_download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,23 +255,31 @@ fn downloaded_release_archive_bytes(
archive: sherpa::SherpaReleaseArchive,
) -> u64 {
let dest = dir.join(archive.file_name);
let (extracted, extracted_complete) = extracted_release_archive_bytes(dir, model_alias);
if extracted_complete {
return extracted;
}
if let Ok(meta) = std::fs::metadata(&dest) {
return meta.len();
}
let partial = partial_actual_size(&dest.with_extension("partial"));
if partial > 0 {
return partial;
}
partial.max(extracted)
}

fn extracted_release_archive_bytes(dir: &Path, model_alias: &str) -> (u64, bool) {
if let Ok(files) = sherpa::required_files_for_alias(model_alias) {
let total: u64 = files
.iter()
.map(|f| path_size_recursive(&dir.join(f)))
.sum();
if total > 0 {
return total;
let mut total = 0;
let mut complete = true;
for file in files {
let path = dir.join(file);
total += path_size_recursive(&path);
if !sherpa::required_path_is_valid(model_alias, file, &path) {
complete = false;
}
}
return (total, complete);
}
0
(0, false)
}

fn path_size_recursive(path: &Path) -> u64 {
Expand Down Expand Up @@ -546,15 +554,26 @@ async fn run_release_archive_download(
},
);
});
let result = download_one(
&client,
archive.url,
&archive_path,
total_bytes,
Arc::clone(&cancel),
on_progress,
)
.await;
let archive_file = info
.files
.first()
.ok_or_else(|| anyhow::anyhow!("release archive file info missing"))?;
let result = if archive_file_is_verified(&archive_path, archive_file) {
Ok(())
} else {
if archive_path.exists() {
remove_path_if_exists(&archive_path)?;
}
download_one(
&client,
archive.url,
&archive_path,
total_bytes,
Arc::clone(&cancel),
on_progress,
)
.await
};
if cancel.load(Ordering::SeqCst) {
emit_cancelled(app, model_alias, file_count, total_bytes);
return Ok(());
Expand All @@ -563,10 +582,14 @@ async fn run_release_archive_download(
emit_failed(app, model_alias, file_count, total_bytes, &error);
return Err(error);
}
if let Err(error) = verify_file(&archive_path, archive_file) {
emit_failed(app, model_alias, file_count, total_bytes, &error);
return Err(error);
}
let archive_path_for_extract = archive_path.clone();
let dir_for_extract = dir.to_path_buf();
let model_alias_for_extract = model_alias.to_string();
tauri::async_runtime::spawn_blocking(move || {
let extract_result = tauri::async_runtime::spawn_blocking(move || {
extract_release_archive(
&archive_path_for_extract,
&dir_for_extract,
Expand All @@ -575,16 +598,23 @@ async fn run_release_archive_download(
)
})
.await
.map_err(|error| anyhow::anyhow!("extract join failed: {error:#}"))??;
.map_err(|error| anyhow::anyhow!("extract join failed: {error:#}"))
.and_then(|result| result);
if let Err(error) = extract_result {
emit_failed(app, model_alias, file_count, total_bytes, &error);
return Err(error);
}
let cached_bytes = downloaded_bytes(model_alias);
let finished_total_bytes = total_bytes.max(cached_bytes);
emit(
app,
DownloadProgress {
model_id: model_alias.to_string(),
file: String::new(),
file_index: file_count,
file_count,
bytes_downloaded: total_bytes,
bytes_total: total_bytes,
bytes_downloaded: cached_bytes,
bytes_total: finished_total_bytes,
phase: DownloadPhase::Finished,
error: None,
},
Expand Down Expand Up @@ -612,7 +642,14 @@ fn extract_release_archive(
if !root.exists() {
anyhow::bail!("archive root missing: {}", root.display());
}
for required in sherpa::required_files_for_alias(model_alias)? {
let required_files = sherpa::required_files_for_alias(model_alias)?;
for required in required_files {
let src = root.join(required);
if !sherpa::required_path_is_valid(model_alias, required, &src) {
anyhow::bail!("archive required path missing: {}", src.display());
}
}
for required in required_files {
let src = root.join(required);
let dest = dir.join(required);
move_path(&src, &dest)?;
Expand Down Expand Up @@ -698,6 +735,10 @@ fn file_is_verified(path: &Path, file: &SherpaRemoteFile) -> bool {
path.exists() && verify_file(path, file).is_ok()
}

fn archive_file_is_verified(path: &Path, file: &SherpaRemoteFile) -> bool {
path.exists() && (file.size > 0 || file.sha256.is_some()) && verify_file(path, file).is_ok()
}

fn verify_file(path: &Path, file: &SherpaRemoteFile) -> Result<()> {
let meta =
std::fs::metadata(path).with_context(|| format!("stat failed: {}", path.display()))?;
Expand Down Expand Up @@ -786,3 +827,52 @@ fn emit_failed(
},
);
}

#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use std::path::{Path, PathBuf};

struct TempModelDir(PathBuf);

impl TempModelDir {
fn new(label: &str) -> Self {
let path = std::env::temp_dir().join(format!(
"openless-sherpa-download-{label}-{}",
uuid::Uuid::new_v4()
));
fs::create_dir_all(&path).expect("create temp model dir");
Self(path)
}

fn path(&self) -> &Path {
&self.0
}
}

impl Drop for TempModelDir {
fn drop(&mut self) {
let _ = fs::remove_dir_all(&self.0);
}
}

#[test]
fn release_archive_downloaded_bytes_uses_extracted_assets_after_archive_removed() {
let alias = "qwen3-asr-0.6b-int8";
let archive = sherpa::release_archive_for_alias(alias).expect("release archive");
let dir = TempModelDir::new("release-archive-extracted");
fs::write(dir.path().join("conv_frontend.onnx"), b"abc").expect("write conv frontend");
fs::write(dir.path().join("encoder.int8.onnx"), b"encod").expect("write encoder");
fs::write(dir.path().join("decoder.int8.onnx"), b"decoder").expect("write decoder");
fs::create_dir_all(dir.path().join("tokenizer")).expect("create tokenizer dir");
fs::write(dir.path().join("tokenizer").join("tokenizer.json"), b"tok")
.expect("write tokenizer file");

assert!(!dir.path().join(archive.file_name).exists());
assert_eq!(
downloaded_release_archive_bytes(dir.path(), alias, archive),
18
);
}
}
11 changes: 9 additions & 2 deletions openless-all/app/src-tauri/src/asr/local/sherpa_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,15 @@ impl SherpaOnnxRuntime {
for model in &mut catalog {
let dir = sherpa::model_dir_for_alias(&model.alias)?;
model.cached = sherpa::required_files_for_alias(&model.alias)
.map(|files| files.iter().all(|file| dir.join(file).exists()))
.map(|files| {
files.iter().all(|file| {
let path = dir.join(file);
sherpa::required_path_is_valid(&model.alias, file, &path)
})
})
.unwrap_or(false);
model.downloaded_bytes =
crate::asr::local::sherpa_download::downloaded_bytes(&model.alias);
model.file_size_mb = model_dir_size_mb(&dir);
}
Ok(catalog)
Expand Down Expand Up @@ -256,7 +263,7 @@ fn validate_alias(alias: &str) -> Result<()> {
fn ensure_required_files(alias: &str, dir: &Path) -> Result<()> {
for file in sherpa::required_files_for_alias(alias)? {
let path = dir.join(file);
if !path.exists() {
if !sherpa::required_path_is_valid(alias, file, &path) {
anyhow::bail!(
"sherpa-onnx model file missing: {}. Place model files under {}",
file,
Expand Down
29 changes: 28 additions & 1 deletion openless-all/app/src-tauri/src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ pub fn get_default_style_system_prompts() -> StyleSystemPrompts {
trait SettingsWriter {
fn read_settings(&self) -> UserPreferences;
fn write_settings(&self, prefs: UserPreferences) -> Result<(), String>;
fn sync_active_asr_provider(&self, provider: &str) -> Result<(), String>;
fn refresh_dictation_hotkey(&self);
fn refresh_qa_hotkey(&self);
fn refresh_combo_hotkey(&self);
Expand All @@ -94,6 +95,10 @@ impl SettingsWriter for Coordinator {
self.prefs().set(prefs).map_err(|e| e.to_string())
}

fn sync_active_asr_provider(&self, provider: &str) -> Result<(), String> {
self.sync_active_asr_provider_to_vault(provider)
}

fn refresh_dictation_hotkey(&self) {
self.update_hotkey_binding();
}
Expand Down Expand Up @@ -128,6 +133,10 @@ impl<T: SettingsWriter + ?Sized> SettingsWriter for Arc<T> {
(**self).write_settings(prefs)
}

fn sync_active_asr_provider(&self, provider: &str) -> Result<(), String> {
(**self).sync_active_asr_provider(provider)
}

fn refresh_dictation_hotkey(&self) {
(**self).refresh_dictation_hotkey();
}
Expand Down Expand Up @@ -167,7 +176,12 @@ fn persist_settings<T: SettingsWriter>(
let translation_changed = previous.translation_hotkey != prefs.translation_hotkey;
let switch_style_changed = previous.switch_style_hotkey != prefs.switch_style_hotkey;
let open_app_changed = previous.open_app_hotkey != prefs.open_app_hotkey;
let active_asr_provider_changed = previous.active_asr_provider != prefs.active_asr_provider;
let active_asr_provider = prefs.active_asr_provider.clone();
coord.write_settings(prefs)?;
if active_asr_provider_changed {
coord.sync_active_asr_provider(&active_asr_provider)?;
}
if dictation_shortcut_changed || dictation_mode_changed {
coord.refresh_dictation_hotkey();
}
Expand Down Expand Up @@ -3290,6 +3304,7 @@ mod tests {
#[derive(Default)]
struct FakeSettingsWriter {
saved: Mutex<Option<UserPreferences>>,
active_asr_provider_syncs: Mutex<Vec<String>>,
dictation_refreshes: Mutex<u32>,
qa_refreshes: Mutex<u32>,
combo_refreshes: Mutex<u32>,
Expand Down Expand Up @@ -3561,6 +3576,14 @@ mod tests {
Ok(())
}

fn sync_active_asr_provider(&self, provider: &str) -> Result<(), String> {
self.active_asr_provider_syncs
.lock()
.unwrap()
.push(provider.to_string());
Ok(())
}

fn refresh_dictation_hotkey(&self) {
*self.dictation_refreshes.lock().unwrap() += 1;
}
Expand Down Expand Up @@ -3749,7 +3772,7 @@ mod tests {
}

#[test]
fn persist_settings_skips_hotkey_refresh_when_shortcuts_unchanged() {
fn persist_settings_syncs_active_asr_provider_without_hotkey_refresh() {
let writer = FakeSettingsWriter::default();
let previous = UserPreferences::default();
*writer.saved.lock().unwrap() = Some(previous.clone());
Expand All @@ -3775,6 +3798,10 @@ mod tests {
.expect("settings saved");
assert_eq!(saved.active_asr_provider, prefs.active_asr_provider);
assert_eq!(saved.microphone_device_name, prefs.microphone_device_name);
assert_eq!(
writer.active_asr_provider_syncs.lock().unwrap().clone(),
vec![prefs.active_asr_provider.clone()]
);
assert_eq!(*writer.dictation_refreshes.lock().unwrap(), 0);
assert_eq!(*writer.combo_refreshes.lock().unwrap(), 0);
assert_eq!(*writer.qa_refreshes.lock().unwrap(), 0);
Expand Down
10 changes: 10 additions & 0 deletions openless-all/app/src-tauri/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,16 @@ impl Coordinator {
pub fn prefs(&self) -> &PreferencesStore {
&self.inner.prefs
}
pub fn sync_active_asr_provider_from_preferences(&self) -> Result<(), String> {
let provider = self.inner.prefs.get().active_asr_provider;
self.sync_active_asr_provider_to_vault(&provider)
}
pub fn sync_active_asr_provider_to_vault(&self, provider: &str) -> Result<(), String> {
if CredentialsVault::get_active_asr() == provider {
return Ok(());
}
CredentialsVault::set_active_asr_provider(provider).map_err(|e| e.to_string())
}
pub fn style_packs(&self) -> &StylePackStore {
&self.inner.style_packs
}
Expand Down
4 changes: 4 additions & 0 deletions openless-all/app/src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ pub fn run() {
));
#[cfg(not(target_os = "windows"))]
let coordinator = Arc::new(coordinator::Coordinator::new());
#[cfg(target_os = "windows")]
if let Err(error) = coordinator.sync_active_asr_provider_from_preferences() {
log::warn!("[startup] sync active ASR provider from preferences failed: {error}");
}
let local_asr_download_manager = Arc::new(asr::local::DownloadManager::new());

tauri::Builder::default()
Expand Down
Loading