
Model Loading Failed in Colab #159

Open
Catoverflow opened this issue Jul 8, 2024 · 5 comments

Catoverflow commented Jul 8, 2024

I tried to run MT3 in Colab, but it failed. I'm not familiar with the DNN libraries, so I'm only posting steps to reproduce here.

Steps to Reproduce

  1. Choose T4 GPU as the runtime type.
  2. Run the Setup Environment cell. It errors with:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.17.0 which is incompatible.
  3. Run the Import and Definitions cell.
  4. Run the Load Model cell with either mt3 or ismir2021; the notebook errors with:
---------------------------------------------------------------------------

XlaRuntimeError                           Traceback (most recent call last)

<ipython-input-12-e1bcd991ed4d> in <cell line: 13>()
     11 
     12 log_event('loadModelStart', {'event_category': MODEL})
---> 13 inference_model = InferenceModel(checkpoint_path, MODEL)
     14 log_event('loadModelComplete', {'event_category': MODEL})

13 frames

<ipython-input-11-30d8629039fb> in __init__(self, checkpoint_path, model_type)
     85 
     86     # Restore from checkpoint.
---> 87     self.restore_from_checkpoint(checkpoint_path)
     88 
     89   @property

<ipython-input-11-30d8629039fb> in restore_from_checkpoint(self, checkpoint_path)
    120   def restore_from_checkpoint(self, checkpoint_path):
    121     """Restore training state from checkpoint, resets self._predict_fn()."""
--> 122     train_state_initializer = t5x.utils.TrainStateInitializer(
    123       optimizer_def=self.model.optimizer_def,
    124       init_fn=self.model.get_initial_variables,

/usr/local/lib/python3.10/dist-packages/t5x/utils.py in __init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
   1057     self._partitioner = partitioner
   1058     self.global_train_state_shape = jax.eval_shape(
-> 1059         initialize_train_state, rng=jax.random.PRNGKey(0)
   1060     )
   1061     self.train_state_axes = partitioner.get_mesh_axes(

/usr/local/lib/python3.10/dist-packages/jax/_src/random.py in PRNGKey(seed, impl)
    231     and ``fold_in``.
    232   """
--> 233   return _return_prng_keys(True, _key('PRNGKey', seed, impl))
    234 
    235 

/usr/local/lib/python3.10/dist-packages/jax/_src/random.py in _key(ctor_name, seed, impl_spec)
    193         f"{ctor_name} accepts a scalar seed, but was given an array of "
    194         f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 195   return prng.random_seed(seed, impl=impl)
    196 
    197 def key(seed: int | ArrayLike, *,

/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py in random_seed(seeds, impl)
    531   # use-case of instantiating with Python hashes in X32 mode.
    532   if isinstance(seeds, int):
--> 533     seeds_arr = jnp.asarray(np.int64(seeds))
    534   else:
    535     seeds_arr = jnp.asarray(seeds)

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in asarray(a, dtype, order, copy)
   3287   if dtype is not None:
   3288     dtype = dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True)  # type: ignore[assignment]
-> 3289   return array(a, dtype=dtype, copy=bool(copy), order=order)
   3290 
   3291 

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   3212     raise TypeError(f"Unexpected input type for array: {type(object)}")
   3213 
-> 3214   out_array: Array = lax_internal._convert_element_type(
   3215       out, dtype, weak_type=weak_type)
   3216   if ndmin > ndim(out_array):

/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in _convert_element_type(operand, new_dtype, weak_type)
    557     return type_cast(Array, operand)
    558   else:
--> 559     return convert_element_type_p.bind(operand, new_dtype=new_dtype,
    560                                        weak_type=bool(weak_type))
    561 

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in bind(self, *args, **params)
    414     assert (not config.enable_checks.value or
    415             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 416     return self.bind_with_trace(find_top_trace(args), args, params)
    417 
    418   def bind_with_trace(self, trace, args, params):

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in bind_with_trace(self, trace, args, params)
    418   def bind_with_trace(self, trace, args, params):
    419     with pop_level(trace.level):
--> 420       out = trace.process_primitive(self, map(trace.full_raise, args), params)
    421     return map(full_lower, out) if self.multiple_results else full_lower(out)
    422 

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in process_primitive(self, primitive, tracers, params)
    919       return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    920     else:
--> 921       return primitive.impl(*tracers, **params)
    922 
    923   def process_call(self, primitive, f, tracers, params):

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
     85   prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     86   try:
---> 87     outs = fun(*args)
     88   finally:
     89     lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 15 frame]

/usr/local/lib/python3.10/dist-packages/jax/_src/compiler.py in backend_compile(backend, module, options, host_callbacks)
    236   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    237   # to take in `host_callbacks`
--> 238   return backend.compile(built_c, compile_options=options)
    239 
    240 def compile_or_get_cached(

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
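
For anyone triaging this: the failure happens as soon as JAX initializes its GPU backend (the `PRNGKey` call is just the first thing that touches it), so a minimal check like the following (standard JAX calls, not part of the original notebook) reproduces it without loading the model:

import jax

print("jax version:", jax.__version__)
try:
    # jax.devices() forces backend initialization, so it hits the same
    # "DNN library initialization failed" error early if the CUDA/cuDNN
    # libraries jaxlib expects don't match what the runtime provides.
    print("devices:", jax.devices())
    print("default backend:", jax.default_backend())
except Exception as e:
    print("backend init failed:", e)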
@ntamotsu

same issue

@goel-raghav

same problem here

@goel-raghav

Seems like either switching the runtime to CPU or changing:

!python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

to

!python3 -m pip install nest-asyncio pyfluidsynth==1.3.0 -e .

fixes the problem. Not sure which one exactly, because I ran out of Colab GPU hours.
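
If you want to test the CPU path without touching the install line, one option (assuming a recent JAX release, where the `JAX_PLATFORMS` environment variable is the supported switch) is:

import os
# Must be set before JAX initializes any backend in the notebook.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
print(jax.default_backend())  # expected to print "cpu"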

@Catoverflow (Author)

@goel-raghav Yes, after changing the code it says `WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.`, and the model loads successfully.
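
That warning matches the second pip line above: without `jax[cuda12_local]`, only the CPU-only `jaxlib` wheel is installed. One way to confirm which situation a runtime is in (a quick check using ordinary pip metadata, nothing MT3-specific) is:

!pip list | grep -iE "jax|cuda"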

@laqieer

laqieer commented Jul 27, 2024

#160 works for GPU, which is much faster than CPU.
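
The change in #160 isn't quoted here, but one common way to get a GPU-capable JAX on a CUDA 12 Colab runtime is the pip-bundled CUDA extra, which installs pip-managed CUDA/cuDNN wheels instead of relying on the libraries preinstalled in the Colab image (a sketch, not necessarily what #160 does):

!python3 -m pip install -U "jax[cuda12]"

After reinstalling, restart the runtime so the new jaxlib is picked up.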
