Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restructuring neural models + addition of OT-FM and GENOT #468

Merged
merged 193 commits into from
Apr 3, 2024

Conversation

MUCDK
Copy link
Contributor

@MUCDK MUCDK commented Nov 22, 2023

This is a PR for

  1. new base classes neural solvers and models (i.e. neural networks)
  2. Incorporating unbalancedness and learning the rescaling factors for any neural OT model.
  3. adding OTFlowMatching and, related to this, classes for flows and time samplers
  4. adding GENOT (with extension to conditional GENOT)
  5. adding drafts of data loaders.

Following this PR, the implementations of ICNN-based solvers and the Monge Gap model should be adapted and extended to the unbalanced setting.

Moreover, wrt typing, I replaced jnp.ndarray by jax.Array

What remains to be done, but I would prefer to do in a separate PR

  1. add graph costs to OTFM and GENOT, i.e. functions which compute batch-wise graphs, and compute costs, e.g. geodesic Sinkhorn or convolutional Wasserstein from this.
  2. implementations of ICNN-based solvers and the Monge Gap model should be adapted and extended to the unbalanced setting.

Copy link

codecov bot commented Nov 22, 2023

Codecov Report

Attention: Patch coverage is 91.97652% with 41 lines in your changes are missing coverage. Please review.

Project coverage is 90.71%. Comparing base (41906a2) to head (f227d54).
Report is 1 commits behind head on main.

❗ Current head f227d54 differs from pull request most recent head 6f9a77c. Consider uploading reports for the commit 6f9a77c to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #468      +/-   ##
==========================================
+ Coverage   90.64%   90.71%   +0.07%     
==========================================
  Files          60       68       +8     
  Lines        6682     7046     +364     
  Branches      956      996      +40     
==========================================
+ Hits         6057     6392     +335     
- Misses        477      494      +17     
- Partials      148      160      +12     
Files Coverage Δ
src/ott/datasets.py 95.00% <100.00%> (ø)
src/ott/initializers/neural/meta_initializer.py 92.42% <100.00%> (ø)
src/ott/neural/methods/monge_gap.py 91.20% <100.00%> (ø)
src/ott/neural/methods/neuraldual.py 57.14% <100.00%> (ø)
src/ott/neural/networks/layers/conjugate.py 100.00% <ø> (ø)
src/ott/neural/networks/layers/posdef.py 91.30% <ø> (ø)
src/ott/neural/networks/layers/time_encoder.py 100.00% <100.00%> (ø)
src/ott/neural/networks/velocity_field.py 100.00% <100.00%> (ø)
src/ott/problems/linear/potentials.py 91.44% <100.00%> (ø)
src/ott/solvers/linear/lineax_implicit.py 100.00% <100.00%> (ø)
... and 8 more

... and 1 file with indirect coverage changes

@MUCDK
Copy link
Contributor Author

MUCDK commented Nov 22, 2023

co-authored by @lucaeyring

@MUCDK MUCDK changed the title draft of BaseNeuralSolver and UnbalancednessMixin restructuring neural models + addition of OT-FM and GENOT Nov 23, 2023
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

return jnp.full_like(t, fill_value=self.sigma)


class BrownianNoiseFlow(StraightFlow):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"BrownianNoiseFlow" is ambiguous to me, it could correspond to Brownian motion and associated flow e.g. VE Flow in eqn 16 of https://arxiv.org/pdf/2210.02747.pdf

src/ott/neural/flow_models/genot.py Outdated Show resolved Hide resolved
src/ott/neural/flow_models/flows.py Outdated Show resolved Hide resolved
@michalk8
Copy link
Collaborator

michalk8 commented Mar 29, 2024

TODOs left:

  • check the docstring for wrong links
  • check imports in all the neural tutorials
  • update the documentation



@pytest.fixture()
def lin_cond_dl() -> DataLoader:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a "conditional OT data loader".

@michalk8 michalk8 self-requested a review April 3, 2024 13:00
Copy link
Collaborator

@michalk8 michalk8 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @MUCDK and everyone, LGTM, merging this!

@michalk8 michalk8 merged commit 713b9fc into ott-jax:main Apr 3, 2024
1 check passed
michalk8 added a commit that referenced this pull request Jun 27, 2024
* draft of BaseSolver and UnbalancedMixin

* draft of BaseSolver and UnbalancedMixin

* [ci skip] continue flow matching implementation

* [ci skip] continue flow matching implementation

* [ci skip] add neural networks

* [ci skip] add test

* [ci skip] resolve import errors

* [ci skip] MRO not working

* [ci skip] basic test for flow matching passes

* [ci skip] add tests for FM with conditions and conditional OT with FM

* [ci skip] add genot outline

* [ci skip] restructure genot

* [ci skip] restructure genot

* [ci skip] fix transport

* [ci skip] flow matching tests passing

* [ci skip] add more tests genot

* [ci skip] add more tests genot

* [ci skip] add TimeSampler

* [ci skip] add docs for TimeSampler and Flow

* [ci skip] add docs for OTFlowMatching and replace jnp.ndarray by jax.Array

* [ci skip] change init arguments of GENOT and add docstrings to GENOT

* [ci skip] split nets into base_models and models

* [ci skip] add references

* add tests for learning the rescaling factors

* [ci skip] partially fix rescaling factor learning

* [ci skip] fix rescaling factor learning

* [ci skip] all tests passing but k_samples_per_x in genot

* k_samples_per_x working in GENOT

* [ci skip] changed dataloaders to numpy and dict return

* [ci skip] changed dataloaders to numpy and dict return

* revert jax.Array to jnp.ndarray

* move dataloader from tests to module

* add docstrings to neurcal networks

* [ci skip] adapt type of scale_cost and cost_fn

* [ci skip] clean code

* [ci skip] fix genot tests

* [ci skip] fix otfm tests

* [ci skip] fix otfm tests

* add scale cost to otfm

* incorporate feedback partially

* resolve circular import errors

* resolve a few pre-commit errors

* resolve pre-commit errors

* resolve pre-commit errors

* fix rng bug

* Update pre-commit

* fix import error

* Run linter

* replace rng jnp.ndarray type by jax.array

* replace rng jnp.ndarray type by jax.array

* fix import error

* [ci skip] start to incorporate feedback

* restructure neural module

* fix import errors

* incorporate feedback partially

* make time encoder a layer

* make conditions Optional and minor feedback

* revert faulty jax.array / jnp.ndarray conversions

* make formatting in neural nets nicer

* add description to Velocity Field

* replace time sampler class by function

* add citations

* add more references

* rename keys_model to rng

* fix tests regarding time sampling

* fix typo in tests

* rename neural_vector_field to velocity_field everywhere

* fix OTFlowMatching.transport

* fix rescaling networks

* Update src/ott/neural/flows/flows.py

Co-authored-by: nvesseron <[email protected]>

* Update src/ott/neural/flows/flows.py

Co-authored-by: nvesseron <[email protected]>

* test for scale_cost

* update test for scale_cost

* fix bug for scale_cost

* fix bug for scale_cost

* jit solve_ode in genot

* incorporate changes partially

* [ci skip] intermediate save

* [ci skip] neural base solver update

* make resamlpemixin a class

* incorporate more changes

* move noise sampling to flows

* fix bug in passing rngs in otfm

* introduce otmatcher in otfm

* [ci skip] split GENOT into GENOTLin and GENOTQuad

* remove dictionaries in OTFM and GENOT classes

* change logic in match_latent_to_data in genot

* change data loaders / data sets

* finish data loader refactoring

* Update linter

* fix bug in _resample_data`

* incorporate more changes

* add docs

* incorporate more changes

* problem with custom type

* fix scale cost bug

* fix bugs

* fux bug in unbalancedness/rescalingMlp

* unify unbalancedness step in GENOT

* change OTDataSet and OTFlowMatching to 4 data loaderes

* Fix bug in the `ConditionalOTDataset`

* Polish docs in the `flows.py`

* Update `OTFM`

* Fix small bugs in `OTFM`

* Polish layers

* Fix typo in citation

* More polish for the docs

* remove print statements and unbalancednesshandler

* remove tests

* make genot training loops more similar to otfm training loop

* adapt tests to the extent possible

* Add weights to sampling

* Start cleaning matchers

* Add conditional sampling + resampling

* Add initial quad matcher

* Improve typing

* Remove `base_solver.py`

* Add TODO

* Update datasets, fix OTFM tests

* Start cleaning GENOT

* Update GENOT

* Remove old GENOTLin/GENOTQuad

* Remove axis swapping

* Remove old todo

* Fix OTFM tests

* Remove `MLPBlock` and `RescalingMLP`

* Add forgotten license

* Remove `__post_init__` from `VF`

* Move cyclical time encoder

* Move more stuff to `utils`

* Remove `samplers.py`

* Rename `cond_dim` -> `condition_dim`

* Nicer formatting

* Fix bug when sampling from the target

* Fix another bug when sampling from the data

* Add initial test for GW

* Remove old GENOT tests

* Remove old dataloaders

* Add more todos

* add docs to dataloader

* expose args in GENOT

* add docs and adapt data_match_fn

* fix linting

* fix data loading and add genot fused tests

* genot tests passing

* adapt docs

* adapt docs

* add error message

* clean docs

* comprise genot tests

* change reference for GENOT

* add missing docstring

* Modify behaviour of `ConditionalLoader`

* Update docstring

* Clean GENOT docs

* Improve VF

* Simplify GENOT test

* Better metadata wrapper in tests

* Fix condition in GENOT test

* Add quad cond dl

* Add conf fused DL

* Polish docs

* Remove conditional loader

* Fix link in the docs

* Improve VF

* Fix GENOT test

* Polish docs

* Remove `uniform_marginals` argument

* Fix undefined variable

* Update `GENOT.transport` docs

* Add `diffrax` to `conf.py`

* Restructure files

* Fix neural init tests import

* Update `docs/`

* Update Monge Gap

* Update MetaOT and NeuralDual

* Update ICNN inits

* Fix links to neural in the docs

* Check for condition dim in VF

* Don't use activation fn in the last layer of VF

* Update assertions

* Try skipping OTFM/GENOT tests temporarily

* Be extra verbose when intalling packages

* Remove `torch` dependency

* Remove `torch` from tests in `pyproject.toml`

* [ci skip] Update docstrings

---------

Co-authored-by: lucaeyring <[email protected]>
Co-authored-by: Michal Klein <[email protected]>
Co-authored-by: nvesseron <[email protected]>
Co-authored-by: Dominik Klein <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants