• CampAny9995@alien.topB

    My experience is that JAX is much lower level and doesn’t come with batteries included, so you have to pick your own optimization library or module abstraction. But I also find it makes way more sense than PyTorch (‘requires_grad’, really?), and JAX’s autograd is substantially better thought out and more robust than PyTorch’s (my background was in compilers and autodiff before moving into deep learning during my postdocs, so I have dug into that side of things). Plus the support for TPUs makes life a bit easier compared to competing for GPU instances on AWS.
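
    A minimal sketch of the contrast being drawn: in JAX, gradients come from the `jax.grad` function transformation rather than from flagging tensors with an attribute like PyTorch’s `requires_grad`. The loss function and array shapes are made up purely for illustration.

    ```python
    import jax
    import jax.numpy as jnp

    # Differentiation is a function transformation: pass a pure function
    # to jax.grad and get back a new function that computes its gradient.
    def loss(w, x, y):
        pred = jnp.dot(x, w)
        return jnp.mean((pred - y) ** 2)

    grad_loss = jax.grad(loss)        # gradient w.r.t. the first argument (w)

    w = jnp.ones(3)
    x = jnp.arange(6.0).reshape(2, 3)
    y = jnp.array([1.0, 2.0])
    print(grad_loss(w, x, y))         # same shape as w
    ```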

    • Due-Wall-915@alien.topB

      It’s a drop-in replacement for NumPy. It doesn’t get sexier than that. I use it for my research on PDE solvers and deep learning, and being able to write plain NumPy code with automatic differentiation on top is very useful. Previously I was looking at autodiff frameworks like Tapenade, but that’s not required anymore.
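
      A minimal sketch of what “NumPy with automatic differentiation” looks like in practice; the finite-difference residual below is a made-up stand-in, not anything from the parent’s actual PDE research.

      ```python
      import jax
      import jax.numpy as jnp

      # jax.numpy mirrors the NumPy API, so existing array code mostly
      # carries over unchanged, and jax.grad differentiates through it.
      def residual(u, dx):
          # Simple periodic 1-D finite-difference Laplacian as a stand-in
          # for a PDE residual.
          lap = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
          return jnp.sum(lap**2)

      u0 = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 64))
      dres_du = jax.grad(residual)(u0, 0.1)   # gradient w.r.t. the field u
      print(dres_du.shape)                    # (64,)
      ```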