Skip to content

Commit

Permalink
Merge pull request tensorflow#396 from adamcrume/master
Browse files Browse the repository at this point in the history
Fix lints and update generated code
  • Loading branch information
adamcrume committed Feb 12, 2023
2 parents a7a4dad + 0557f8b commit 08bd27e
Show file tree
Hide file tree
Showing 12 changed files with 9,246 additions and 559 deletions.
3,291 changes: 3,291 additions & 0 deletions src/eager/op/raw_ops.rs

Large diffs are not rendered by default.

16 changes: 5 additions & 11 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@ impl ImportGraphDefOptions {
/// unique. Defaults to false.
pub fn set_uniquify_names(&mut self, uniquify_names: bool) {
unsafe {
tf::TF_ImportGraphDefOptionsSetUniquifyNames(
self.inner,
if uniquify_names { 1 } else { 0 },
);
tf::TF_ImportGraphDefOptionsSetUniquifyNames(self.inner, u8::from(uniquify_names));
}
}

Expand All @@ -179,10 +176,7 @@ impl ImportGraphDefOptions {
/// treated as an error. This option has no effect if no prefix is specified.
pub fn set_uniquify_prefix(&mut self, uniquify_prefix: bool) {
unsafe {
tf::TF_ImportGraphDefOptionsSetUniquifyPrefix(
self.inner,
if uniquify_prefix { 1 } else { 0 },
);
tf::TF_ImportGraphDefOptionsSetUniquifyPrefix(self.inner, u8::from(uniquify_prefix));
}
}

Expand Down Expand Up @@ -693,7 +687,7 @@ impl Graph {
tf::TF_GraphToFunction(
self.inner(),
fn_name_cstr.as_ptr(),
if append_hash_to_fn_name { 1 } else { 0 },
u8::from(append_hash_to_fn_name),
num_opers,
c_opers_ptr,
c_inputs.len() as c_int,
Expand Down Expand Up @@ -2033,7 +2027,7 @@ impl<'a> OperationDescription<'a> {
) -> std::result::Result<(), NulError> {
let c_attr_name = CString::new(attr_name)?;
unsafe {
tf::TF_SetAttrBool(self.inner, c_attr_name.as_ptr(), if value { 1 } else { 0 });
tf::TF_SetAttrBool(self.inner, c_attr_name.as_ptr(), u8::from(value));
}
Ok(())
}
Expand All @@ -2045,7 +2039,7 @@ impl<'a> OperationDescription<'a> {
value: &[bool],
) -> std::result::Result<(), NulError> {
let c_attr_name = CString::new(attr_name)?;
let c_value: Vec<c_uchar> = value.iter().map(|x| if *x { 1 } else { 0 }).collect();
let c_value: Vec<c_uchar> = value.iter().map(|x| u8::from(*x)).collect();
unsafe {
tf::TF_SetAttrBoolList(
self.inner,
Expand Down
Loading

0 comments on commit 08bd27e

Please sign in to comment.