Skip to content

Commit

Permalink
Merge pull request sharksforarms#335 from IniterWorker/feature/reader…
Browse files Browse the repository at this point in the history
…-catch-all

Add `default` attribute for DekuRead enums
  • Loading branch information
wcampbell0x2a authored May 19, 2023
2 parents dccf88a + 4c2a5d9 commit d045596
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 4 deletions.
8 changes: 8 additions & 0 deletions deku-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ struct VariantData {

/// variant `id_pat` value
id_pat: Option<TokenStream>,

/// variant `default` option
default: Option<bool>,
}

impl VariantData {
Expand All @@ -568,6 +571,7 @@ impl VariantData {
writer: receiver.writer?,
id: receiver.id,
id_pat: receiver.id_pat?,
default: receiver.default,
};

VariantData::validate(&ret)?;
Expand Down Expand Up @@ -843,6 +847,10 @@ struct DekuVariantReceiver {
/// variant `id_pat` value
#[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")]
id_pat: Result<Option<TokenStream>, ReplacementError>,

/// variant `id` value
#[darling(default)]
default: Option<bool>,
}

/// Entry function for `DekuRead` proc-macro
Expand Down
26 changes: 24 additions & 2 deletions deku-derive/src/macros/deku_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
let magic_read = emit_magic_read(input);

let mut has_default_match = false;
let mut default_reader = None;
let mut pre_match_tokens = Vec::with_capacity(variants.len());
let mut variant_matches = Vec::with_capacity(variants.len());
let mut deku_ids = Vec::with_capacity(variants.len());
Expand Down Expand Up @@ -198,6 +199,7 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {

let variant_ident = &variant.ident;
let variant_reader = &variant.reader;
let variant_has_default = variant.default.unwrap_or(false);

let variant_read_func = if variant_reader.is_some() {
quote! { #variant_reader; }
Expand All @@ -210,7 +212,6 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
.iter()
.filter(|f| !f.is_temp)
.map(|f| &f.field_ident);

let internal_fields = gen_internal_field_idents(variant_is_named, field_idents);
let initialize_enum =
super::gen_enum_init(variant_is_named, variant_ident, internal_fields);
Expand Down Expand Up @@ -243,6 +244,16 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
}
};

// register `default`
if default_reader.is_some() && variant_has_default {
return Err(syn::Error::new(
variant.ident.span(),
"DekuRead: `default` must be specified only once",
));
} else if default_reader.is_none() && variant_has_default {
default_reader = Some(variant_read_func.clone())
}

variant_matches.push(quote! {
#variant_id => {
#variant_read_func
Expand All @@ -251,7 +262,7 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
}

// if no default match, return error
if !has_default_match {
if !has_default_match && default_reader.is_none() {
variant_matches.push(quote! {
_ => {
return Err(::#crate_::DekuError::Parse(
Expand All @@ -265,6 +276,17 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
});
}

// if default
if !has_default_match {
if let Some(variant_read_func) = default_reader {
variant_matches.push(quote! {
_ => {
#variant_read_func
}
});
}
}

let variant_id_read = if id.is_some() {
quote! {
let (__deku_new_rest, __deku_variant_id) = (__deku_rest, (#id));
Expand Down
30 changes: 30 additions & 0 deletions examples/enums_catch_all.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use deku::prelude::*;
use hexlit::hex;
use std::convert::TryFrom;
use std::convert::TryInto;

#[derive(Clone, Copy, PartialEq, Eq, Debug, DekuWrite, DekuRead)]
#[deku(type = "u8")]
#[non_exhaustive]
#[repr(u8)]
pub enum DekuTest {
/// A
#[deku(id = "1")]
A = 0,
/// B
#[deku(id = "2")]
B = 1,
/// C
#[deku(id = "3", default)]
C = 2,
}

fn main() {
let input = hex!("0A").to_vec();
let output = hex!("03").to_vec();

let ret_read = DekuTest::try_from(input.as_slice()).unwrap();
assert_eq!(DekuTest::C, ret_read);
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
assert_eq!(output.to_vec(), ret_write);
}
7 changes: 5 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,17 @@ First the "type" is read using the `type`, then is matched against the
variants given `id`. What happens after is the same as structs!
This is implemented with the [id](/attributes/index.html#id),
[id_pat](/attributes/index.html#id_pat) and
[id_pat](/attributes/index.html#id_pat), [default](/attributes/index.html#default) and
[type](attributes#type) attributes. See these for more examples.
If no `id` is specified, the variant will default to it's discriminant value.
If no variant can be matched, a [DekuError::Parse](crate::error::DekuError)
If no variant can be matched and the `default` is not provided, a [DekuError::Parse](crate::error::DekuError)
error will be returned.
If no variant can be matched and the `default` is provided, a variant will be returned
based on the field marked with `default`.
Example:
```rust
Expand Down
86 changes: 86 additions & 0 deletions tests/test_catch_all.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#[cfg(test)]
mod test {
use deku::prelude::*;
use std::convert::TryFrom;
use std::convert::TryInto;

/// Basic test struct
#[derive(Clone, Copy, PartialEq, Eq, Debug, DekuWrite, DekuRead)]
#[deku(type = "u8")]
#[non_exhaustive]
#[repr(u8)]
pub enum BasicMapping {
/// A
A = 0,
/// B
B = 1,
/// C
#[deku(default)]
C = 2,
}

/// Advanced test struct
#[derive(Clone, Copy, PartialEq, Eq, Debug, DekuWrite, DekuRead)]
#[deku(type = "u8")]
#[non_exhaustive]
#[repr(u8)]
pub enum AdvancedRemapping {
/// A
#[deku(id = "1")]
A = 0,
/// B
#[deku(id = "2")]
B = 1,
/// C
#[deku(id = "3", default)]
C = 2,
}

#[test]
fn test_basic_a() {
let input = [0u8];
let ret_read = BasicMapping::try_from(input.as_slice()).unwrap();
assert_eq!(BasicMapping::A, ret_read);
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
assert_eq!(input.to_vec(), ret_write);
}

#[test]
fn test_basic_c() {
let input = [2u8];
let ret_read = BasicMapping::try_from(input.as_slice()).unwrap();
assert_eq!(BasicMapping::C, ret_read);
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
assert_eq!(input.to_vec(), ret_write);
}

#[test]
fn test_basic_pattern() {
let input = [10u8];
let output = [BasicMapping::C as u8];
let ret_read = BasicMapping::try_from(input.as_slice()).unwrap();
assert_eq!(BasicMapping::C, ret_read);
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
assert_eq!(output.to_vec(), ret_write);
}

#[test]
fn test_advanced_remapping() {
let input = [1u8];
let output = [1u8];
let ret_read = AdvancedRemapping::try_from(input.as_slice()).unwrap();
assert_eq!(AdvancedRemapping::A, ret_read);
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
assert_eq!(output.to_vec(), ret_write);
}

#[test]
fn test_advanced_remapping_default_field() {
let input = [10u8];
let output = [3u8];
let ret_read = AdvancedRemapping::try_from(input.as_slice()).unwrap();
assert_eq!(AdvancedRemapping::C, ret_read);
let ret_write: Vec<u8> = ret_read.try_into().unwrap();
assert_eq!(output.to_vec(), ret_write);
}
}
12 changes: 12 additions & 0 deletions tests/test_compile/cases/catch_all_multiple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use deku::prelude::*;

#[derive(DekuRead)]
#[deku(type = "u8")]
enum Test1 {
#[deku(default)]
A = 1,
#[deku(default)]
B = 2,
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/test_compile/cases/catch_all_multiple.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: DekuRead: `default` must be specified only once
--> tests/test_compile/cases/catch_all_multiple.rs:9:5
|
9 | B = 2,
| ^

0 comments on commit d045596

Please sign in to comment.