Skip to content

Commit

Permalink
ensure compatible version of jax is installed with cuda support
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547565102
  • Loading branch information
iansimon authored and Magenta Team committed Jul 12, 2023
1 parent 4a1b1c4 commit 9c6250d
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions mt3/colab/music_transcription_with_transformers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "ibSG_uu0QXgc"
},
"outputs": [],
Expand All @@ -63,13 +64,10 @@
"\n",
"!apt-get update -qq \u0026\u0026 apt-get install -qq libfluidsynth2 build-essential libasound2-dev libjack-dev\n",
"\n",
"# upgrade jax with cuda drivers, otherwise t5x replaces it with non-cuda version\n",
"!pip install \"jax[cuda11_local]\u003e=0.4.10\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"# install mt3\n",
"!git clone --branch=main https://github.com/magenta/mt3\n",
"!mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp\n",
"!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .\n",
"!python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
"\n",
"# copy checkpoints\n",
"!gsutil -q -m cp -r gs:https://mt3/checkpoints .\n",
Expand Down Expand Up @@ -191,7 +189,7 @@
"\n",
" self.batch_size = 8\n",
" self.outputs_length = 1024\n",
" self.sequence_length = {'inputs': self.inputs_length, \n",
" self.sequence_length = {'inputs': self.inputs_length,\n",
" 'targets': self.outputs_length}\n",
"\n",
" self.partitioner = t5x.partitioning.PjitPartitioner(\n",
Expand Down Expand Up @@ -284,7 +282,7 @@
"\n",
" def __call__(self, audio):\n",
" \"\"\"Infer note sequence from audio samples.\n",
" \n",
"\n",
" Args:\n",
" audio: 1-d numpy array of audio samples (16kHz) for a single example.\n",
"\n",
Expand Down Expand Up @@ -442,7 +440,7 @@
" if not note.is_drum))\n",
"})\n",
"\n",
"note_seq.play_sequence(est_ns, synth=note_seq.fluidsynth, \n",
"note_seq.play_sequence(est_ns, synth=note_seq.fluidsynth,\n",
" sample_rate=SAMPLE_RATE, sf2_path=SF2_PATH)\n",
"note_seq.plot_sequence(est_ns)"
]
Expand Down

0 comments on commit 9c6250d

Please sign in to comment.