Alternative design for node, argument handling in burn import #1840

Closed
skewballfox opened this issue May 30, 2024 · 2 comments
Labels
design Design related task onnx

Comments

@skewballfox
Contributor

skewballfox commented May 30, 2024

The current design for handling graph inputs and outputs leaves a bit to be desired in terms of complexity. I think it's important to consider the constraints. If there are constraints I haven't listed or that need to be changed, please let me know and I'll update them.

Implementation requirements

Required by spec or compiler

  • Nodes need to be able to reference the outputs of previous nodes or graph inputs, which were renamed at the (current) final step of node processing.
  • Fusing nodes (currently) means looking one node ahead.
  • Nodes need to mutate their output arguments.
  • Nodes need to mutate their inputs (cast_update_outputs, transpose_linear_weights, check_constants).
  • The node being mutated can't be owned by the vector of previous nodes.
  • Graph inputs and outputs need to be filtered (marked as passed). Node outputs do not.
  • Graph inputs and outputs need to preserve ordering.
  • Some nodes need to be removed, either due to lifting or fusion.
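
As a rough illustration of how several of these constraints interact, here is a minimal sketch of a single-pass loop. It is not burn-import's code: `Node`, `node`, `process`, and the MatMul + Add to Linear fusion rule are all hypothetical stand-ins. The node being mutated is an owned local, so the vector of previous nodes can be read while it changes, and a peekable iterator provides the one-node look-ahead.

```rust
/// Hypothetical, simplified stand-in for a graph node (not burn-import's type).
#[derive(Debug, Clone, PartialEq)]
pub struct Node {
    pub op: String,
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
}

/// Small constructor helper to keep examples short (hypothetical).
pub fn node(op: &str, inputs: &[&str], outputs: &[&str]) -> Node {
    Node {
        op: op.to_string(),
        inputs: inputs.iter().map(|s| s.to_string()).collect(),
        outputs: outputs.iter().map(|s| s.to_string()).collect(),
    }
}

/// Single pass over the nodes. Returns the processed nodes and the graph
/// inputs (names that no earlier node produced) in first-seen order.
pub fn process(nodes: Vec<Node>) -> (Vec<Node>, Vec<String>) {
    let mut processed: Vec<Node> = Vec::with_capacity(nodes.len());
    let mut graph_inputs: Vec<String> = Vec::new();
    let mut iter = nodes.into_iter().peekable();

    while let Some(mut current) = iter.next() {
        // Fusion: look exactly one node ahead. The rule here (a MatMul whose
        // output feeds the following Add) is an assumption for illustration.
        let fuse = current.op == "MatMul"
            && iter.peek().map_or(false, |next| {
                next.op == "Add" && next.inputs.iter().any(|i| current.outputs.contains(i))
            });
        if fuse {
            let add = iter.next().expect("peek returned Some above");
            // Mutate the current node's inputs: adopt the Add's other input.
            for input in add.inputs {
                if !current.outputs.contains(&input) {
                    current.inputs.push(input);
                }
            }
            // Mutate its outputs and type; the Add node itself is dropped.
            current.outputs = add.outputs;
            current.op = "Linear".to_string();
        }

        // `current` is an owned local, so `processed` can be read freely.
        // A linear scan is used for brevity; a map would keep this O(1).
        for input in &current.inputs {
            let produced = processed.iter().any(|p| p.outputs.contains(input));
            if !produced && !graph_inputs.contains(input) {
                graph_inputs.push(input.clone());
            }
        }
        processed.push(current);
    }
    (processed, graph_inputs)
}
```

Consuming the input vector by value is what lets the current node be mutated without `Rc<RefCell<_>>`; the trade-off is that a transformation needing more than one node of look-ahead would not fit this shape.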

Required for code generation

While these won't lead to a bug, they will lead to arguably incorrect code generation:

  • Argument names are derived from the updated node names.
  • Node names are based on their type, which may change.

Desired for performance

  • Keep the implementation linear, ideally a single pass over all nodes (save for the final filter operation).
  • Keep cloning to what is strictly necessary.
  • Avoid Vec<Rc<RefCell<Node>>>.
  • Minimize duplicated or wasted work.

Desired for the API

  • Limit implicit dependencies on the order of functions. Ideally it should be easy for new contributors to add extra functions to node_generation without something breaking because those functions' changes get overwritten.
  • Minimize the number of arguments being passed between functions, or at least keep the arguments passed consistent.
  • Avoid or hide chained hashmaps (keys -> keys -> values).

My approach in #1795 is to completely separate nodes and arguments. Nodes store a key (the original, unchanged name), so updating an argument doesn't require any synchronization, but this leads to an undesirable API downstream.
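
To make the trade-off concrete, here is a minimal sketch of the "nodes store keys" idea. It is not the code from #1795; `Argument`, `Node`, `ArgStore`, and their fields and methods are hypothetical names chosen for illustration.

```rust
use std::collections::HashMap;

/// Hypothetical argument record; the fields are illustrative only.
#[derive(Debug, Clone, PartialEq)]
pub struct Argument {
    pub name: String,
    pub rank: Option<usize>,
}

/// Nodes hold only keys (the original, unchanged ONNX names).
#[derive(Debug)]
pub struct Node {
    pub name: String,
    pub input_keys: Vec<String>,
    pub output_keys: Vec<String>,
}

/// Single owner of every argument, indexed by original name.
#[derive(Default)]
pub struct ArgStore {
    args: HashMap<String, Argument>,
}

impl ArgStore {
    pub fn insert(&mut self, key: &str, arg: Argument) {
        self.args.insert(key.to_string(), arg);
    }

    /// One update here is visible to every node that holds the key.
    pub fn get_mut(&mut self, key: &str) -> Option<&mut Argument> {
        self.args.get_mut(key)
    }

    /// The downstream cost: resolving a node's inputs always needs the store.
    pub fn inputs_of<'a>(&'a self, node: &Node) -> Vec<&'a Argument> {
        node.input_keys
            .iter()
            .filter_map(|key| self.args.get(key))
            .collect()
    }
}
```

Because there is a single copy of each argument, nothing has to be synchronized after a mutation. The price is that a node alone no longer carries enough information to describe its inputs, which is the downstream API problem mentioned above.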

I had an idea for an alternative design where we could avoid separating nodes and arguments.

Alternative design

  1. Make a stateful proto conversion struct to replace the current convert_node_proto function, moving some of the stuff stored in OnnxGraphBuilder into that struct.
use std::collections::HashMap;

#[derive(Default)]
pub(crate) struct ProtoNodeConverter {
    /// A map from the original ONNX names to the new (variable) names
    io_name_map: HashMap<String, String>,
    /// A counter used to generate unique names for nodes
    node_name_counter: HashMap<NodeType, usize>,
    // ... (remaining fields elided)
}
  2. Rename nodes and arguments prior to returning to graph_builder.
  3. Store a new io_map in graph_builder. Note that the tuple listed below may need to be nested in another enum (to handle graph in/out):
io_map: HashMap<String, (usize, IOEntry)>,
  4. Use the vector of processed nodes in much the same way graph_io is currently used, save mu
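
The first two steps could look roughly like the sketch below. Only the two map fields come from the struct above; the `NodeType` variants, the simplified `Node`, the `new` and `convert` methods, and the `conv2d1` / `conv2d1_out1` / `input1` naming scheme are assumptions made for illustration, and the `io_map` step is not covered.

```rust
use std::collections::HashMap;

/// Hypothetical subset of node types, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeType {
    Conv2d,
    Relu,
}

impl NodeType {
    fn prefix(self) -> &'static str {
        match self {
            NodeType::Conv2d => "conv2d",
            NodeType::Relu => "relu",
        }
    }
}

/// Simplified node that already carries its final names.
#[derive(Debug, PartialEq)]
pub struct Node {
    pub name: String,
    pub node_type: NodeType,
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
}

#[derive(Default)]
pub struct ProtoNodeConverter {
    /// Original ONNX names -> new (variable) names.
    io_name_map: HashMap<String, String>,
    /// Per-type counter used to generate unique node names.
    node_name_counter: HashMap<NodeType, usize>,
}

impl ProtoNodeConverter {
    /// Seeds the name map with the graph inputs ("input1", "input2", ...).
    pub fn new(graph_inputs: &[&str]) -> Self {
        let mut converter = Self::default();
        for (i, name) in graph_inputs.iter().enumerate() {
            converter
                .io_name_map
                .insert((*name).to_string(), format!("input{}", i + 1));
        }
        converter
    }

    /// Renames the node and its arguments before handing the node back.
    pub fn convert(&mut self, node_type: NodeType, inputs: &[&str], outputs: &[&str]) -> Node {
        let count = self.node_name_counter.entry(node_type).or_insert(0);
        *count += 1;
        let name = format!("{}{}", node_type.prefix(), *count);

        // Inputs were named when their producer was converted; names that
        // are not in the map are kept unchanged.
        let inputs: Vec<String> = inputs
            .iter()
            .map(|old| {
                self.io_name_map
                    .get(*old)
                    .cloned()
                    .unwrap_or_else(|| (*old).to_string())
            })
            .collect();

        // Outputs get names derived from the new node name and are recorded
        // so later nodes can resolve them.
        let mut new_outputs = Vec::with_capacity(outputs.len());
        for (i, old) in outputs.iter().enumerate() {
            let new_name = format!("{}_out{}", name, i + 1);
            self.io_name_map.insert((*old).to_string(), new_name.clone());
            new_outputs.push(new_name);
        }

        Node { name, node_type, inputs, outputs: new_outputs }
    }
}
```

With this shape, all naming state lives in one struct and every node leaves the converter already renamed. Note that names are fixed at conversion time, so a later change of node type would leave them stale.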

Advantages

  • All the implicitly ordered work would happen at the very beginning; all other ordering would be arbitrary.
  • No need to have multiple node types or to change the downstream API.

Problems

input mutation

This pattern relied on mutability being required only for graph outputs: an input can be added, swapped, or dropped, but not mutated. It turns out there are three functions (cast_update_outputs, transpose_linear_weights, check_constants) that mutate their inputs. check_constants isn't a big deal because the node it references will be deleted.

For the other two, looking through the code on the main branch, I don't see where the changes to those mutated inputs are synced to graph_io, so I'm a bit fuzzy on whether those mutations actually reach the original arguments, and if not, how that maps to a valid operation (that is, whether the output of the previous node is still correct as the input for the node in question).
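
The general concern can be shown with a small sketch. It does not reproduce burn-import's behavior; `Argument`, `Node`, and `transpose_first_input` are hypothetical. If a node holds its inputs as clones of the producer's outputs, a mutation of the input stays local unless something explicitly syncs it back.

```rust
/// Hypothetical argument with a shape, for illustration only.
#[derive(Debug, Clone, PartialEq)]
pub struct Argument {
    pub name: String,
    pub dims: Vec<usize>,
}

#[derive(Debug)]
pub struct Node {
    pub inputs: Vec<Argument>,
    pub outputs: Vec<Argument>,
}

/// Hypothetical stand-in for an input-mutating step such as a weight
/// transpose: it reverses the dimensions of the first input in place.
pub fn transpose_first_input(node: &mut Node) {
    if let Some(arg) = node.inputs.first_mut() {
        arg.dims.reverse();
    }
}
```

Calling `transpose_first_input` on a consumer whose inputs were cloned from a producer changes only the consumer's copy; the producer's output keeps its original `dims`. Whether that divergence is harmless depends on whether anything else reads the producer's copy afterwards.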

arguably incorrect names

The second problem is that some nodes will be remapped, and thus the names of their outputs will be invalid. Less concerning is that some constants will be removed, which will change the count. More concerning is that you'll have some nodes (and arguments) named after the original type unless you either move anything that might remap into the conversion step (which sort of defeats the purpose) or find a way to split the logic for node transformations into separate detection and transform parts.

Moving data from the functions/build steps that change the node to the proto_conversion, while avoiding multiple mutable references to the same parent struct, involves either passing extra arguments or returning some data and making the node generation functions slightly more complicated.

We could drop this constraint by generating UUIDs for variable names, which would make most of the steps independent. But that would almost certainly make debugging harder and would make the generated code unreadable.

@antimora antimora added design Design related task onnx labels May 30, 2024
@antimora
Collaborator

Related: #1812

@skewballfox
Contributor Author

So I realized something about node argument access that should have been obvious in hindsight: the node output (and graph_output) can be handled separately from the graph input.

  1. The input arguments can be remapped immediately (names and values), as inputs will never be renamed.
  2. Syncing changes to the current node's output (to whatever is tracking arguments) can happen at the very end.
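
A minimal sketch of that two-step flow, under assumed names (`Argument`, `Node`, `ArgTracker`, and its methods are hypothetical, with `ArgTracker` standing in for "whatever is tracking arguments"):

```rust
use std::collections::HashMap;

/// Hypothetical argument record; the fields are illustrative only.
#[derive(Debug, Clone, PartialEq)]
pub struct Argument {
    pub name: String,
    pub elem_type: Option<String>,
}

#[derive(Debug)]
pub struct Node {
    pub inputs: Vec<Argument>,
    pub outputs: Vec<Argument>,
}

/// Tracks the latest version of each argument by its original ONNX name.
#[derive(Default)]
pub struct ArgTracker {
    by_onnx_name: HashMap<String, Argument>,
}

impl ArgTracker {
    /// Registers a graph input (or any argument known up front).
    pub fn register(&mut self, onnx_name: &str, arg: Argument) {
        self.by_onnx_name.insert(onnx_name.to_string(), arg);
    }

    /// Step 1: copy the names and values of the inputs right away. Their
    /// producers are finished, so the copies cannot go stale. Unknown names
    /// are skipped here for brevity.
    pub fn resolve_inputs(&self, onnx_inputs: &[&str]) -> Vec<Argument> {
        onnx_inputs
            .iter()
            .filter_map(|key| self.by_onnx_name.get(*key).cloned())
            .collect()
    }

    /// Step 2: publish the node's outputs once, after every processing
    /// step has had a chance to mutate them.
    pub fn sync_outputs(&mut self, onnx_outputs: &[&str], node: &Node) {
        for (key, arg) in onnx_outputs.iter().zip(&node.outputs) {
            self.by_onnx_name.insert((*key).to_string(), arg.clone());
        }
    }
}
```

Between the two calls the node owns plain `Argument` values, so the processing functions can mutate its outputs in any order without touching the tracker.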
