Skip to main content
← dvs documentation Rust API reference

dvs/
config.rs

1use std::io;
2use std::io::Write;
3use std::path::Path;
4
5use anyhow::{Context, Result};
6use fs_err as fs;
7use serde::{Deserialize, Deserializer, Serialize};
8
9use crate::backends::Backend as BackendTrait;
10use crate::backends::local::LocalBackend;
11use crate::paths::{CONFIG_FILE_NAME, DEFAULT_FOLDER_NAME, find_repo_root};
12use crate::progress::ProgressReader;
13use crate::utils::parse_size;
14
/// Fallback returned by `Config::progress_bytes_threshold` when the config file does
/// not set `cli.progress_threshold`: 500 MiB, i.e. 524_288_000 bytes.
const DEFAULT_PROGRESS_BYTE_SIZE_THRESHOLD: u64 = 500 * 1024 * 1024;
16
17fn deserialize_size_option<'de, D: Deserializer<'de>>(
18    deserializer: D,
19) -> Result<Option<u64>, D::Error> {
20    let s: Option<String> = Option::deserialize(deserializer)?;
21    match s {
22        None => Ok(None),
23        Some(s) => parse_size(&s).map(Some).map_err(serde::de::Error::custom),
24    }
25}
26
27#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Default)]
28#[serde(rename_all = "lowercase")]
29pub enum Compression {
30    None,
31    #[default]
32    Zstd,
33}
34
35impl std::fmt::Display for Compression {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            Compression::None => write!(f, "none"),
39            Compression::Zstd => write!(f, "zstd"),
40        }
41    }
42}
43
44impl Compression {
45    pub fn compress(
46        &self,
47        source: &Path,
48        dest: &Path,
49        on_bytes: Option<&(dyn Fn(u64) + Send + Sync)>,
50    ) -> Result<u64> {
51        match self {
52            Compression::None => {
53                if let Some(cb) = on_bytes {
54                    let input = fs::File::open(source)?;
55                    let output = fs::File::create(dest)?;
56                    let mut reader = ProgressReader::new(input, cb);
57                    let mut writer = io::BufWriter::new(output);
58                    let bytes = io::copy(&mut reader, &mut writer)?;
59                    writer.flush()?;
60                    Ok(bytes)
61                } else {
62                    let bytes = fs::copy(source, dest)?;
63                    Ok(bytes)
64                }
65            }
66            Compression::Zstd => {
67                let input = fs::File::open(source)?;
68                let output = fs::File::create(dest)?;
69
70                if let Some(cb) = on_bytes {
71                    let tracked = ProgressReader::new(input, cb);
72                    let mut encoder = zstd::stream::read::Encoder::new(tracked, 0)?;
73                    let mut writer = io::BufWriter::new(output);
74                    let bytes = io::copy(&mut encoder, &mut writer)?;
75                    writer.flush()?;
76                    Ok(bytes)
77                } else {
78                    let mut encoder = zstd::stream::read::Encoder::new(input, 0)?;
79                    let mut writer = io::BufWriter::new(output);
80                    let bytes = io::copy(&mut encoder, &mut writer)?;
81                    writer.flush()?;
82                    Ok(bytes)
83                }
84            }
85        }
86    }
87
88    pub fn decompress(
89        &self,
90        source: &Path,
91        dest: &Path,
92        on_bytes: Option<&(dyn Fn(u64) + Send + Sync)>,
93    ) -> Result<()> {
94        match self {
95            Compression::None => {
96                if let Some(cb) = on_bytes {
97                    let input = fs::File::open(source)?;
98                    let output = fs::File::create(dest)?;
99                    let mut reader = ProgressReader::new(input, cb);
100                    let mut writer = io::BufWriter::new(output);
101                    io::copy(&mut reader, &mut writer)?;
102                    writer.flush()?;
103                } else {
104                    fs::copy(source, dest)?;
105                }
106                Ok(())
107            }
108            Compression::Zstd => {
109                let input = fs::File::open(source)?;
110                let output = fs::File::create(dest)?;
111
112                let mut decoder = zstd::stream::read::Decoder::new(input)?;
113                let mut writer = io::BufWriter::new(output);
114                if let Some(cb) = on_bytes {
115                    let mut reader = ProgressReader::new(&mut decoder, cb);
116                    io::copy(&mut reader, &mut writer)?;
117                } else {
118                    io::copy(&mut decoder, &mut writer)?;
119                }
120                writer.flush()?;
121                Ok(())
122            }
123        }
124    }
125}
126
127#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
128#[serde(untagged)]
129pub enum Backend {
130    Local(LocalBackend),
131}
132
133#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
134pub struct CliConfig {
135    /// Defaults to 500MB if not set in the config file
136    #[serde(
137        default,
138        skip_serializing_if = "Option::is_none",
139        deserialize_with = "deserialize_size_option"
140    )]
141    progress_threshold: Option<u64>,
142}
143
144#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
145pub struct Config {
146    /// Compression algorithm to use for files in the storage directory
147    compression: Compression,
148    /// By default, all the metadata files (the .dvs files) will be stored in a `.dvs` folder
149    /// at the root of the repository
150    /// If this option is set, dvs will use that folder name instead of `.dvs`
151    metadata_folder_name: Option<String>,
152    backend: Backend,
153    #[serde(default, skip_serializing_if = "Option::is_none")]
154    cli: Option<CliConfig>,
155}
156
157impl Config {
158    pub fn new_local(path: impl AsRef<Path>, group: Option<String>) -> Result<Config> {
159        let backend = LocalBackend::new(path.as_ref(), group)?;
160        Ok(Config {
161            compression: Compression::Zstd,
162            metadata_folder_name: None,
163            backend: Backend::Local(backend),
164            cli: None,
165        })
166    }
167
168    pub fn save(&self, directory: impl AsRef<Path>) -> Result<()> {
169        let config_path = directory.as_ref().join(CONFIG_FILE_NAME);
170        let content = toml::to_string_pretty(&self)?;
171        fs::write(&config_path, content)?;
172        log::info!("Configuration saved to {}", config_path.display());
173        Ok(())
174    }
175
176    pub fn find(current_directory: impl AsRef<Path>) -> Option<Result<Self>> {
177        let repo_root = find_repo_root(current_directory);
178        let config_path = repo_root.join(CONFIG_FILE_NAME);
179        log::debug!("Looking for config at {}", config_path.display());
180        if config_path.exists() {
181            let content = match fs::read_to_string(&config_path) {
182                Ok(c) => c,
183                Err(e) => return Some(Err(e.into())),
184            };
185            Some(
186                toml::from_str(&content)
187                    .with_context(|| format!("Failed to parse {}", config_path.display())),
188            )
189        } else {
190            log::debug!("No config file found at {}", config_path.display());
191            None
192        }
193    }
194
195    pub fn set_metadata_folder_name(&mut self, name: String) {
196        self.metadata_folder_name = Some(name);
197    }
198
199    pub fn metadata_folder_name(&self) -> &str {
200        if let Some(name) = &self.metadata_folder_name {
201            name.as_str()
202        } else {
203            DEFAULT_FOLDER_NAME
204        }
205    }
206
207    pub fn compression(&self) -> Compression {
208        self.compression
209    }
210
211    pub fn set_compression(&mut self, compression: Compression) {
212        self.compression = compression;
213    }
214
215    pub fn backend(&self) -> &dyn BackendTrait {
216        match &self.backend {
217            Backend::Local(b) => b,
218        }
219    }
220
221    pub fn progress_bytes_threshold(&self) -> u64 {
222        self.cli
223            .as_ref()
224            .and_then(|x| x.progress_threshold)
225            .unwrap_or(DEFAULT_PROGRESS_BYTE_SIZE_THRESHOLD)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testutil::create_temp_git_repo;

    /// A freshly built local config compares equal after a save/find round trip.
    #[test]
    fn config_save_and_find_roundtrip() {
        // Keep the temp-dir guard bound so the repository lives until the test ends.
        let (_guard, repo) = create_temp_git_repo();

        let saved = Config::new_local(repo.join(".storage"), None).unwrap();
        saved.save(&repo).unwrap();

        let found = Config::find(&repo).unwrap().unwrap();
        assert_eq!(saved, found);
    }

    /// A git repository without a config file yields `None`.
    #[test]
    fn config_find_returns_none_without_config_file() {
        let (_guard, repo) = create_temp_git_repo();
        let found = Config::find(&repo);
        assert!(found.is_none());
    }

    /// `new_local` rejects a group that does not exist on the system.
    #[cfg(unix)]
    #[test]
    fn new_local_validates_group_exists() {
        let scratch = tempfile::tempdir().unwrap();
        let storage_dir = scratch.path().join(".storage");

        // A group name chosen so it should not exist on any test machine.
        let outcome = Config::new_local(&storage_dir, Some("nonexistent_group_12345".to_string()));
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    /// A custom metadata folder name survives a save/find round trip.
    #[test]
    fn config_with_custom_metadata_folder() {
        let (_guard, repo) = create_temp_git_repo();

        let mut config = Config::new_local(repo.join(".storage"), None).unwrap();
        config.set_metadata_folder_name(String::from(".custom_dvs"));
        config.save(&repo).unwrap();

        let reloaded = Config::find(&repo).unwrap().unwrap();
        assert_eq!(reloaded.metadata_folder_name(), ".custom_dvs");
    }
}