Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-344] Add more operators to onnx import #11856

Merged
merged 6 commits into from
Jul 27, 2018

Conversation

anirudhacharya
Copy link
Member

@anirudhacharya anirudhacharya commented Jul 23, 2018

Description

Add more operators to onnx import - Mean, Acos, Asin, Atan, Cos, Sin, Softplus, Tan, Shape, Gather. HardSigmoid, LpPool, GlobalLpPool, ReduceL1

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http:https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Add more operators to onnx import

Comments

@Roshrini
Copy link
Member

Thanks for working on this. LGTM

Copy link
Contributor

@sandeep-krishnamurthy sandeep-krishnamurthy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@sandeep-krishnamurthy
Copy link
Contributor

@zhreshold - Can you please take a look once?

for op_input in inputs:
concat_input.append(symbol.expand_dims(op_input, axis=0))
concat_sym = symbol.concat(*concat_input, dim=0)
mean_sym = symbol.mean(concat_sym, axis=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is mean on axis 0, are you sure it is the desired behavior?

Copy link
Member Author

@anirudhacharya anirudhacharya Jul 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is correct. We are doing an unsqueeze along axis=0, hence the mean along axis=0. It has also passed all the operator tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so it's elem-wise mean

@@ -348,6 +361,14 @@ def global_avgpooling(attrs, inputs, proto_obj):
'pool_type': 'avg'})
return 'Pooling', new_attrs, inputs

def global_lppooling(attrs, inputs, proto_obj):
"""Performs global lp pooling on the input."""
p_value = attrs['p'] if 'p' in attrs else 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p_value = attrs.get('p', 2)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is lp pooling?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -505,6 +529,30 @@ def exponent(attrs, inputs, proto_obj):
"""Elementwise exponent of input array."""
return 'exp', attrs, inputs

def _cos(attrs, inputs, proto_obj):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why _cos, _sin for example need to start with leading underscore? in python they are under math namescope, so I guess there's no conflict

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there will be a namespace conflict if there is an import math statement. probably better practice to is the leading underscore for names that are very common and prevent namespace collision.

@@ -578,6 +637,20 @@ def avg_pooling(attrs, inputs, proto_obj):

return new_op, new_attrs, inputs

def lp_pooling(attrs, inputs, proto_obj):
"""LP Pooling"""
p_value = attrs['p'] if 'p' in attrs else 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use get rather than if else

"""Mean of two input tensors."""
concat_input = []
for op_input in inputs:
concat_input.append(symbol.expand_dims(op_input, axis=0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use list comprehensive is as simple as one line

concat_input = [symbol.expand_dims(op_input, axis=0) for op_input in inputs]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please address this, the rest looks good to me now.

def global_lppooling(attrs, inputs, proto_obj):
"""Performs global lp pooling on the input."""
p_value = attrs.get('p', 2)
new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the indentation is bad here, two suggested way from PEP8 is

a = super_long_function_name(
    xxx, xxx, xxx, xxx, xx, xxx)

# or 
a = super_long_function_name(xxx, xxx, xxx,
                             xxx, xxx, xxx)

I suggest to use the first style here to avoid too many lines

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the one I have used is the second style and it has passed pylint and it is good code readability-wise. I would ideally want the dict to be printed as name, value pairs rather than as a single long list.

def lp_pooling(attrs, inputs, proto_obj):
"""LP Pooling"""
p_value = attrs.get('p', 2)
new_attrs = translation_utils._fix_attribute_names(attrs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here for the indentation

@zhreshold zhreshold merged commit 4bbf15c into apache:master Jul 27, 2018
@anirudhacharya anirudhacharya deleted the deco1 branch August 13, 2018 19:59
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* add more ops

* use dict.get

* add list comprehensive

* retrigger CI due to unrelated flaky test failure
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants