Working Notes: a commonplace notebook for recording & exploring ideas.
— Kunal
Quick notes on JAX
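- import jax.numpy as jnp
- jax.vmap: use this to auto-batch a function over array dimensions; in_axes (and out_axes) specify which axis of each argument (and of the output) to map over.
- jax.shard_map: map a function over the shards of an array laid out on a device mesh, for manual parallelism.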
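A minimal sketch of jax.vmap, assuming a recent JAX install. `dot` is a hypothetical helper written for single vectors. vmap lifts it to operate over a batch, and in_axes picks the mapped axis for each argument, where None means "don't batch this argument, reuse it for every element".

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for single vectors: a and b both have shape (n,)
    return jnp.dot(a, b)

xs = jnp.arange(6.0).reshape(3, 2)  # a batch of 3 vectors, stored as rows
ys = jnp.arange(6.0).reshape(3, 2)

# Default in_axes=0: map over the leading axis of every argument
batched_dot = jax.vmap(dot)
row_dots = batched_dot(xs, ys)  # shape (3,)

# in_axes=(0, None): map over rows of xs, reuse w unbatched for each row
w = jnp.array([1.0, 1.0])
row_sums = jax.vmap(dot, in_axes=(0, None))(xs, w)

# in_axes=(1, None): the batch lives in the columns instead of the rows
cols = jnp.arange(6.0).reshape(2, 3)  # 3 vectors, stored as columns
col_sums = jax.vmap(dot, in_axes=(1, None))(cols, w)
```

The per-example function stays simple and JAX generates the batched version, which avoids hand-writing the batch dimension into every operation.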
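A minimal sketch of jax.shard_map. The mesh axis name "x" and the function `local_sum_of_squares` are illustrative choices, not part of the notes above. Inside the mapped function each device sees only its own block of the input, so a result that spans devices needs an explicit collective such as jax.lax.psum. Older JAX releases expose the same function as jax.experimental.shard_map.shard_map, which is why the import falls back. The input length is chosen to divide evenly by the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# A 1-D mesh over all available devices (a single CPU works too)
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n_devices = len(devices)

def local_sum_of_squares(block):
    # block is only this device's shard of the input
    partial = jnp.sum(block ** 2)
    # Manual parallelism: combine partial sums across the "x" mesh axis
    return jax.lax.psum(partial, axis_name="x")

sharded_sum = shard_map(
    local_sum_of_squares,
    mesh=mesh,
    in_specs=P("x"),  # split the input's only axis across "x"
    out_specs=P(),    # the psum result is replicated on every device
)

x = jnp.arange(4 * n_devices, dtype=jnp.float32)
result = sharded_sum(x)
```

Unlike vmap, nothing is combined automatically here: the per-device code decides which collectives run, which is what makes shard_map the tool for manual parallelism.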