Major modeling refactoring #165
Conversation
Talking to some people, it seems that "Transductive" is a better name than "Search", since "search" is too broad in scope and the line is a bit blurred regarding what each algorithm specifically does. "Transductive" means "directly optimize the parameters specifically for an instance", which conveys the meaning more easily!
Yep! I remember you mentioned this before, and that was what I used :-)
Great job! I included a few comments and suggestions, but nothing major or important :)
Co-authored-by: ahottung <[email protected]>
I noticed while doing the metaclasses that
NAR refactoring needed @Furffico
LGTM! Leaving some minor comments there.
`# TODO: modularize inside the envs`
This means to add a `num_starts` param in the init td from the environments, right?
Yep, theoretically, it can be obtained through the environments
```{eval-rst}
.. tip::
   Note that in RL4CO we distinguish the RL algorithms and the actors via the following naming:

   * **Model:** Refers to the reinforcement learning algorithm encapsulated within a `LightningModule`. This module is responsible for training the policy.
   * **Policy:** Implemented as a `nn.Module`, this neural network (often referred to as the *actor*) takes an instance and outputs a sequence of actions, :math:`\pi = \pi_0, \pi_1, \dots, \pi_N`, which constitutes the solution.

   Here, :math:`\pi_i` represents the action taken at step :math:`i`, forming a sequence that leads to the optimal or near-optimal solution for the given instance.
```
We could mention here or somewhere else that the abstract classes under rl4co/models/common are not expected to be instantiated directly. For example, if you want to use an autoregressive policy, you may want to init an AM model instead of AutoregressivePolicy(); the same goes for the NAR, improvement, and transductive classes.
LGTM! Nice documentation.
Important: Thanks for your revisions! We are planning to merge the PR into
Great job on the refactoring! I only have one minor comment regarding the configuration of different model-policy combinations. Maybe we can add the example to the Hydra tutorial.
configs/model/am-ppo.yaml
Regarding different algorithm-architecture combinations, it might be better to configure those combinations using Hydra. In fact, using Hydra's nested instantiation we can already do something like this:
```yaml
# @package _global_

model:
  _target_: rl4co.models.PPO
  policy:
    _target_: rl4co.models.AttentionModelPolicy
    env_name: ${env.name}
  ppo_epochs: 4
  metrics:
    train: ["loss", "reward", "surrogate_loss", "value_loss", "entropy_bonus"]
```
Might be beneficial to note this in the docs / examples
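For reference, here is a rough plain-Python equivalent of what that nested instantiation would build; the `TSPEnv` environment and the exact constructor arguments are illustrative assumptions, not a prescription of the current API:

```python
# Rough plain-Python equivalent of the Hydra config above (a sketch, not the exact API).
from rl4co.envs import TSPEnv  # any environment would do; TSPEnv is just an example
from rl4co.models import PPO, AttentionModelPolicy

env = TSPEnv(num_loc=20)                          # stands in for the ${env.name} resolution
policy = AttentionModelPolicy(env_name=env.name)  # the "architecture" part of the combo
model = PPO(env, policy=policy, ppo_epochs=4)     # the "algorithm" part of the combo
```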
Description
This PR is for a major, long-overdue refactoring of the RL4CO codebase 😄
Motivation and Context
So far, we had mostly overfitted RL4CO to the autoregressive Attention Model structure (encoder-decoder). However, there are several models that do not necessarily follow this, such as DeepACO. Implementing such a model requires changes to the structure, which then starts to become non-standard, and it could be hard for newcomers to implement a different model type. For this reason, some rethinking of the library on the modeling side is necessary!
Tip
Note that in RL4CO we refer to *model* as the RL algorithm and *policy* as the neural network that, given an instance, gives back a sequence of actions $\pi_0, \pi_1, \dots, \pi_N$, i.e. the solution. In other words: the model is a `LightningModule` that trains the policy, which is a `nn.Module`.
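As a minimal sketch of this split (class names follow the `rl4co.models` namespace used elsewhere in this PR; constructor arguments may differ from the current API):

```python
# Sketch of the model/policy split: the policy is the network, the model trains it.
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel, AttentionModelPolicy

env = TSPEnv(num_loc=20)
policy = AttentionModelPolicy(env_name=env.name)  # nn.Module: instance -> sequence of actions
model = AttentionModel(env, policy=policy)        # LightningModule: RL algorithm that trains the policy
```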
New structure
With the new structure, the aim is to categorize NCO approaches (which are not necessarily trained with RL!) into the following: 1) constructive, 2) improvement, 3) transductive.
1) Constructive (policy)
Constructive NCO pre-trains a policy to amortize the inference. "Constructive" means that a solution is created from scratch by the model. We can also categorize constructive NCO into two sub-categories depending on the roles of the encoder and decoder:
1a) Autoregressive (AR)
Autoregressive approaches use a decoder that outputs log probabilities for the current solution. These approaches generate a solution step by step, similarly to e.g. LLMs. They have an encoder-decoder structure (e.g. AM). Some models may not have an encoder at all and just re-encode at each step (e.g. BQ-NCO).
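As a rough illustration (not the exact RL4CO interface), an AR rollout could look like the sketch below; the `policy.encoder`/`policy.decoder` split and the TensorDict/TorchRL-style `env.step` usage are assumptions:

```python
# Hedged sketch of an autoregressive rollout: encode once, then decode step by step.
import torch

def greedy_ar_rollout(policy, env, td):
    hidden = policy.encoder(td)                    # one-shot encoding of the instance
    actions, log_probs = [], []
    while not td["done"].all():
        log_p = policy.decoder(td, hidden)         # per-step log-probabilities over actions
        action = log_p.argmax(dim=-1)              # greedy decoding, for illustration only
        log_probs.append(log_p.gather(-1, action.unsqueeze(-1)).squeeze(-1))
        actions.append(action)
        td.set("action", action)
        td = env.step(td)["next"]                  # advance the environment state
    return torch.stack(actions, -1), torch.stack(log_probs, -1)
```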
1b) NonAutoregressive (NAR)
The difference between AR and NAR approaches is that NAR approaches only use an encoder (they encode in one shot) and generate, for example, a heatmap, which can then be decoded simply by using it as a probability distribution or by applying some search method on top (e.g. DeepACO).
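A rough sketch of the NAR idea (illustrative names, not the DeepACO implementation):

```python
# Hedged sketch of non-autoregressive decoding: one-shot heatmap, then a simple decode.
import torch

def nar_decode(policy, td):
    heatmap = policy.encoder(td)            # e.g. scores over edges, produced in one shot
    probs = torch.softmax(heatmap, dim=-1)  # interpret the heatmap as a distribution
    return probs.argmax(dim=-1)             # or sample / run a search method (e.g. ACO) on top
```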
2) Improvement (policy)
These methods differ from constructive NCO in that they can obtain better solutions similarly to how local search algorithms work: they can improve the solutions over time. This is different from decoding strategies or similar techniques in constructive methods, since these policies are trained to perform improvement operations.
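For intuition, a single improvement "action" could be something like a 2-opt move on a TSP tour; this is generic local-search code, not the RL4CO improvement-policy API, and an improvement policy would learn to pick the indices instead of enumerating them:

```python
# Generic 2-opt move as an example of an improvement operation (not the RL4CO API).
import torch

def two_opt(tour: torch.Tensor, i: int, j: int) -> torch.Tensor:
    """Reverse the segment tour[i..j], a classic improvement move."""
    new_tour = tour.clone()
    new_tour[i : j + 1] = torch.flip(tour[i : j + 1], dims=[0])
    return new_tour

tour = torch.arange(10)          # a (trivial) starting solution
improved = two_opt(tour, 2, 6)   # an improvement policy would output the pair (2, 6)
```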
Note: You may have a look here for the basic constructive NCO policy structure! ;)
3) Transductive (model)
Tip
Read the definition of inductive vs transductive RL. In inductive RL, we train to generalize to new instances. In transductive RL we train (or finetune) to solve only specific ones.
Transductive models are learning algorithms that optimize on a specific instance: they improve solutions by updating the policy parameters $\theta$, which means that we are running optimization (backprop) during online testing. Transductive learning can be performed with different policies: for example, EAS updates (a part of) the parameters of AR policies to obtain better solutions, but I guess there are other ways (or papers out there I don't know of) to optimize at test time.
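Conceptually (and very loosely in the spirit of EAS), test-time optimization of part of the policy could look like the sketch below; the policy call signature and output keys are assumptions, not the exact RL4CO API:

```python
# Hedged sketch of transductive (test-time) optimization on a single instance.
import torch

def transductive_optimize(policy, env, td, steps=100, lr=1e-3):
    params = list(policy.decoder.parameters())   # e.g. adapt only a part of the policy
    optimizer = torch.optim.Adam(params, lr=lr)
    for _ in range(steps):
        out = policy(td.clone(), env, decode_type="sampling")
        loss = -(out["reward"] * out["log_likelihood"]).mean()  # REINFORCE-style objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return policy
```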
In practice, here is what the structure looks like right now:
Changelog
- Renamed `search` to `transductive`
- `embedding_dim` -> `embed_dim` (see PyTorch)
- `env_name` as a mandatory parameter
- Added `evaluate`, which simply takes in an action if provided and gets its log probs
- Removed `evaluate_action`, since it can be simply done via the above!

Types of changes
TODO
Extra
- `policy.encoder` + `value_head` (this way, any model should be able to have a critic)

Special thanks to @LTluttmann for your help and feedback~
Do you have some ideas / feedback on the above PR?
CC: @Furffico @henry-yeh @ahottung @bokveizen
Also tagging @yining043 for the coming improvement methods