diff --git a/src/dict.rs b/src/dict.rs index c5160d7..1861757 100644 --- a/src/dict.rs +++ b/src/dict.rs @@ -2,6 +2,8 @@ use std::borrow::Cow; use unicase::UniCase; +use crate::tokens::Case; + #[derive(Default)] pub struct Dictionary {} @@ -18,7 +20,8 @@ impl Dictionary { } pub fn correct_word<'s, 'w>(&'s self, word: crate::tokens::Word<'w>) -> Option> { - map_lookup(&crate::dict_codegen::WORD_DICTIONARY, word.token()).map(|s| s.into()) + map_lookup(&crate::dict_codegen::WORD_DICTIONARY, word.token()) + .map(|s| case_correct(s, word.case())) } } @@ -37,3 +40,45 @@ fn map_lookup( map.get(&UniCase(key)).cloned() } } + +fn case_correct(correction: &str, case: Case) -> Cow<'_, str> { + match case { + Case::Lower | Case::None => correction.into(), + Case::Title => { + let mut title = String::with_capacity(correction.as_bytes().len()); + let mut char_indices = correction.char_indices(); + if let Some((_, c)) = char_indices.next() { + title.extend(c.to_uppercase()); + if let Some((i, _)) = char_indices.next() { + title.push_str(&correction[i..]); + } + } + title.into() + } + Case::Scream => correction + .chars() + .flat_map(|c| c.to_uppercase()) + .collect::() + .into(), + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_case_correct() { + let cases = [ + ("foo", Case::Lower, "foo"), + ("foo", Case::None, "foo"), + ("foo", Case::Title, "Foo"), + ("foo", Case::Scream, "FOO"), + ("fOo", Case::None, "fOo"), + ]; + for (correction, case, expected) in cases.iter() { + let actual = case_correct(correction, *case); + assert_eq!(*expected, actual); + } + } +}