The creator of Keras confirmed that the new version comes out in a few days. Keras becomes multi-backend again, with support for PyTorch, TensorFlow, and JAX. Personally, I'm excited to be able to try JAX without having to deep-dive into the documentation and the entire ecosystem. What about you?
As the others said, it's a pain to reimplement common layers in JAX specifically. PyTorch is much higher-level in its nn API, but personally I despise rewriting the same training loop for every implementation. That's why even JAX users rely on Flax for common layers: why use an error-prone primitive like jax.lax.conv_general_dilated and fill in its ten-odd arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and save myself a million debugging sessions. With the PyTorch backend, model.fit() can suffice for quick work and be customized later.
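For anyone who hasn't tried it, here's a minimal sketch of what that workflow looks like with multi-backend Keras (using the released keras package name; the preview shipped as keras_core, and the dataset here is just random arrays for illustration). The KERAS_BACKEND environment variable picks the backend before import, and the same Sequential + fit() code runs on JAX, TensorFlow, or PyTorch:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"; must be set before importing keras

import numpy as np
import keras

# One high-level Conv2D layer instead of hand-wiring jax.lax.conv_general_dilated
model = keras.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(32, kernel_size=3, activation="relu"),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation="softmax"),
])

model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# model.fit() provides the training loop for free, regardless of backend
x = np.random.rand(64, 28, 28, 1).astype("float32")  # dummy data, assumption
y = np.random.randint(0, 10, size=(64,))
model.fit(x, y, epochs=1, batch_size=16)
```

And if fit() stops being enough, you can still override train_step or drop down to the backend's native ops later, which is the "customized later" part.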