1use std::io;
2use std::io::Write;
3use std::path::Path;
4
5use anyhow::{Context, Result};
6use fs_err as fs;
7use serde::{Deserialize, Deserializer, Serialize};
8
9use crate::backends::Backend as BackendTrait;
10use crate::backends::local::LocalBackend;
11use crate::paths::{CONFIG_FILE_NAME, DEFAULT_FOLDER_NAME, find_repo_root};
12use crate::progress::ProgressReader;
13use crate::utils::parse_size;
14
/// Fallback for [`Config::progress_bytes_threshold`] when the config file does
/// not set `cli.progress_threshold`: 500 MiB.
const DEFAULT_PROGRESS_BYTE_SIZE_THRESHOLD: u64 = 500 * 1024 * 1024;
16
17fn deserialize_size_option<'de, D: Deserializer<'de>>(
18 deserializer: D,
19) -> Result<Option<u64>, D::Error> {
20 let s: Option<String> = Option::deserialize(deserializer)?;
21 match s {
22 None => Ok(None),
23 Some(s) => parse_size(&s).map(Some).map_err(serde::de::Error::custom),
24 }
25}
26
27#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Default)]
28#[serde(rename_all = "lowercase")]
29pub enum Compression {
30 None,
31 #[default]
32 Zstd,
33}
34
35impl std::fmt::Display for Compression {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Compression::None => write!(f, "none"),
39 Compression::Zstd => write!(f, "zstd"),
40 }
41 }
42}
43
44impl Compression {
45 pub fn compress(
46 &self,
47 source: &Path,
48 dest: &Path,
49 on_bytes: Option<&(dyn Fn(u64) + Send + Sync)>,
50 ) -> Result<u64> {
51 match self {
52 Compression::None => {
53 if let Some(cb) = on_bytes {
54 let input = fs::File::open(source)?;
55 let output = fs::File::create(dest)?;
56 let mut reader = ProgressReader::new(input, cb);
57 let mut writer = io::BufWriter::new(output);
58 let bytes = io::copy(&mut reader, &mut writer)?;
59 writer.flush()?;
60 Ok(bytes)
61 } else {
62 let bytes = fs::copy(source, dest)?;
63 Ok(bytes)
64 }
65 }
66 Compression::Zstd => {
67 let input = fs::File::open(source)?;
68 let output = fs::File::create(dest)?;
69
70 if let Some(cb) = on_bytes {
71 let tracked = ProgressReader::new(input, cb);
72 let mut encoder = zstd::stream::read::Encoder::new(tracked, 0)?;
73 let mut writer = io::BufWriter::new(output);
74 let bytes = io::copy(&mut encoder, &mut writer)?;
75 writer.flush()?;
76 Ok(bytes)
77 } else {
78 let mut encoder = zstd::stream::read::Encoder::new(input, 0)?;
79 let mut writer = io::BufWriter::new(output);
80 let bytes = io::copy(&mut encoder, &mut writer)?;
81 writer.flush()?;
82 Ok(bytes)
83 }
84 }
85 }
86 }
87
88 pub fn decompress(
89 &self,
90 source: &Path,
91 dest: &Path,
92 on_bytes: Option<&(dyn Fn(u64) + Send + Sync)>,
93 ) -> Result<()> {
94 match self {
95 Compression::None => {
96 if let Some(cb) = on_bytes {
97 let input = fs::File::open(source)?;
98 let output = fs::File::create(dest)?;
99 let mut reader = ProgressReader::new(input, cb);
100 let mut writer = io::BufWriter::new(output);
101 io::copy(&mut reader, &mut writer)?;
102 writer.flush()?;
103 } else {
104 fs::copy(source, dest)?;
105 }
106 Ok(())
107 }
108 Compression::Zstd => {
109 let input = fs::File::open(source)?;
110 let output = fs::File::create(dest)?;
111
112 let mut decoder = zstd::stream::read::Decoder::new(input)?;
113 let mut writer = io::BufWriter::new(output);
114 if let Some(cb) = on_bytes {
115 let mut reader = ProgressReader::new(&mut decoder, cb);
116 io::copy(&mut reader, &mut writer)?;
117 } else {
118 io::copy(&mut decoder, &mut writer)?;
119 }
120 writer.flush()?;
121 Ok(())
122 }
123 }
124 }
125}
126
127#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
128#[serde(untagged)]
129pub enum Backend {
130 Local(LocalBackend),
131}
132
133#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
134pub struct CliConfig {
135 #[serde(
137 default,
138 skip_serializing_if = "Option::is_none",
139 deserialize_with = "deserialize_size_option"
140 )]
141 progress_threshold: Option<u64>,
142}
143
144#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
145pub struct Config {
146 compression: Compression,
148 metadata_folder_name: Option<String>,
152 backend: Backend,
153 #[serde(default, skip_serializing_if = "Option::is_none")]
154 cli: Option<CliConfig>,
155}
156
157impl Config {
158 pub fn new_local(path: impl AsRef<Path>, group: Option<String>) -> Result<Config> {
159 let backend = LocalBackend::new(path.as_ref(), group)?;
160 Ok(Config {
161 compression: Compression::Zstd,
162 metadata_folder_name: None,
163 backend: Backend::Local(backend),
164 cli: None,
165 })
166 }
167
168 pub fn save(&self, directory: impl AsRef<Path>) -> Result<()> {
169 let config_path = directory.as_ref().join(CONFIG_FILE_NAME);
170 let content = toml::to_string_pretty(&self)?;
171 fs::write(&config_path, content)?;
172 log::info!("Configuration saved to {}", config_path.display());
173 Ok(())
174 }
175
176 pub fn find(current_directory: impl AsRef<Path>) -> Option<Result<Self>> {
177 let repo_root = find_repo_root(current_directory);
178 let config_path = repo_root.join(CONFIG_FILE_NAME);
179 log::debug!("Looking for config at {}", config_path.display());
180 if config_path.exists() {
181 let content = match fs::read_to_string(&config_path) {
182 Ok(c) => c,
183 Err(e) => return Some(Err(e.into())),
184 };
185 Some(
186 toml::from_str(&content)
187 .with_context(|| format!("Failed to parse {}", config_path.display())),
188 )
189 } else {
190 log::debug!("No config file found at {}", config_path.display());
191 None
192 }
193 }
194
195 pub fn set_metadata_folder_name(&mut self, name: String) {
196 self.metadata_folder_name = Some(name);
197 }
198
199 pub fn metadata_folder_name(&self) -> &str {
200 if let Some(name) = &self.metadata_folder_name {
201 name.as_str()
202 } else {
203 DEFAULT_FOLDER_NAME
204 }
205 }
206
207 pub fn compression(&self) -> Compression {
208 self.compression
209 }
210
211 pub fn set_compression(&mut self, compression: Compression) {
212 self.compression = compression;
213 }
214
215 pub fn backend(&self) -> &dyn BackendTrait {
216 match &self.backend {
217 Backend::Local(b) => b,
218 }
219 }
220
221 pub fn progress_bytes_threshold(&self) -> u64 {
222 self.cli
223 .as_ref()
224 .and_then(|x| x.progress_threshold)
225 .unwrap_or(DEFAULT_PROGRESS_BYTE_SIZE_THRESHOLD)
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutil::create_temp_git_repo;

    /// Saving a fresh local config and finding it again yields an equal value.
    #[test]
    fn config_save_and_find_roundtrip() {
        let (_tmp_guard, repo) = create_temp_git_repo();
        let written = Config::new_local(repo.join(".storage"), None).unwrap();
        written.save(&repo).unwrap();

        let read_back = Config::find(&repo)
            .expect("config should be found after save")
            .expect("saved config should parse");
        assert_eq!(written, read_back);
    }

    /// `find` reports `None` for a repository with no config file.
    #[test]
    fn config_find_returns_none_without_config_file() {
        let (_tmp_guard, repo) = create_temp_git_repo();
        let lookup = Config::find(&repo);
        assert!(lookup.is_none());
    }

    /// Creating a local config with a group that does not exist must fail.
    #[cfg(unix)]
    #[test]
    fn new_local_validates_group_exists() {
        let scratch = tempfile::tempdir().unwrap();
        let err = Config::new_local(
            scratch.path().join(".storage"),
            Some("nonexistent_group_12345".to_string()),
        )
        .expect_err("an unknown group should be rejected");
        assert!(err.to_string().contains("not found"));
    }

    /// A custom metadata folder name survives a save/find round trip.
    #[test]
    fn config_with_custom_metadata_folder() {
        let (_tmp_guard, repo) = create_temp_git_repo();
        let mut cfg = Config::new_local(repo.join(".storage"), None).unwrap();
        cfg.set_metadata_folder_name(".custom_dvs".to_string());
        cfg.save(&repo).unwrap();

        let reloaded = Config::find(&repo).unwrap().unwrap();
        assert_eq!(reloaded.metadata_folder_name(), ".custom_dvs");
    }
}