CoAtNet (https://arxiv.org/abs/2106.04803) and MaxViT (https://arxiv.org/abs/2204.01697) timm trained weights
Weights were created by reproducing the paper architectures and exploring timm-specific additions such as ConvNeXt blocks, parallel partitioning, and other experiments.
Weights were trained on a mix of TPU and GPU systems. The bulk of the weights were trained on TPUs via the TRC program (https://sites.research.google/trc/about/).
CoAtNet variants run particularly well on TPU; it's a great combination. MaxViT is better suited to GPU due to the window partitioning, although there are some optimizations that improve TPU padding/utilization, including using a 256x256 image size with an (8, 8) window/grid size, and keeping tensors in NCHW format for partition attention when using PyTorch XLA.
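To see why 256x256 with an (8, 8) window/grid size avoids padding, the quick check below computes how much padding partition attention would need at each stage (an illustrative sketch only; the partition_padding helper is hypothetical, not timm code, and assumes the usual stride 4/8/16/32 stage reductions):

```python
# Hypothetical helper (not part of timm): padding needed to make each stage's
# feature map divisible by the window/grid size used for partition attention.
def partition_padding(img_size, window, stage_strides=(4, 8, 16, 32)):
    pads = {}
    for stride in stage_strides:
        feat = img_size // stride                         # feature map side at this stage
        pads[stride] = (window - feat % window) % window  # extra rows/cols to pad
    return pads

print(partition_padding(256, 8))  # {4: 0, 8: 0, 16: 0, 32: 0} -> no padding, full TPU utilization
print(partition_padding(224, 8))  # later stages need padding (28 -> pad 4, 14 -> pad 2, 7 -> pad 1)
```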
Glossary:
- coatnet - CoAtNet (MBConv + transformer blocks)
- coatnext - CoAtNet w/ ConvNeXt conv blocks
- maxvit - MaxViT (MBConv + block (ala swin) and grid partitioning transformer blocks)
- maxxvit - MaxViT w/ ConvNeXt conv blocks
- rmlp - relative position embedding w/ MLP (can be resized) -- if this isn't in the model name, it's using relative position bias (ala swin)
- rw - my variations on the model, slight differences in sizing / pooling / etc from the Google paper spec
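The prefixes and suffixes above map onto registered timm model names, which can be browsed with the standard registry helpers (a small sketch; the exact set of names depends on the installed timm version):

```python
import timm

# Wildcard filters match against registered model names.
print(timm.list_models('coatnet*'))   # CoAtNet variants (incl. rmlp / bn / rw versions)
print(timm.list_models('coatnext*'))  # CoAtNet w/ ConvNeXt conv blocks
print(timm.list_models('maxvit*'))    # MaxViT variants
print(timm.list_models('maxxvit*'))   # MaxViT w/ ConvNeXt conv blocks
```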
Results (top-1 accuracy @ eval image size):
- maxvit_rmlp_pico_rw_256 - 80.5 @ 256, 81.3 @ 320 (T)
- coatnet_nano_rw_224 - 81.7 @ 224 (T)
- coatnext_nano_rw_224 - 82.0 @ 224 (G) -- (uses ConvNeXt block, no BatchNorm)
- coatnet_rmlp_nano_rw_224 - 82.0 @ 224, 82.8 @ 320 (T)
- coatnet_0_rw_224 - 82.4 (T) -- NOTE: timm '0' coatnets have 2 more 3rd stage blocks
- coatnet_bn_0_rw_224 - 82.4 (T) -- all BatchNorm, no LayerNorm
- maxvit_nano_rw_256 - 82.9 @ 256 (T)
- maxvit_rmlp_nano_rw_256 - 83.0 @ 256, 83.6 @ 320 (T)
- maxxvit_rmlp_nano_rw_256 - 83.0 @ 256, 83.7 @ 320 (G) (uses ConvNeXt conv block, no BatchNorm)
- coatnet_rmlp_1_rw_224 - 83.4 @ 224, 84.0 @ 320 (T)
- maxvit_tiny_rw_224 - 83.5 @ 224 (G)
- coatnet_1_rw_224 - 83.6 @ 224 (G)
- maxvit_rmlp_tiny_rw_256 - 84.2 @ 256, 84.8 @ 320 (T)
- maxvit_rmlp_small_rw_224 - 84.5 @ 224, 85.1 @ 320 (G)
- maxxvit_rmlp_small_rw_256 - 84.6 @ 256, 84.9 @ 288 (G) -- could be trained better, hparams need tuning (uses ConvNeXt conv block, no BN)
- coatnet_rmlp_2_rw_224 - 84.6 @ 224, 85.0 @ 320 (T)
(T) = TPU trained with bits_and_tpu branch training code, (G) = GPU trained
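A minimal usage sketch for one of the weights above, using the standard timm API to load the model and resolve its pretrained eval preprocessing (the model name is taken from the results list; exact availability depends on the installed timm version):

```python
import torch
import timm
from timm.data import resolve_data_config, create_transform

# Load one of the listed weights, e.g. maxvit_rmlp_nano_rw_256 (83.0 top-1 @ 256).
model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True)
model.eval()

# Resolve the eval-time preprocessing (resize / crop / normalization) from the
# pretrained config; apply `transform` to a PIL image for real inference.
cfg = resolve_data_config({}, model=model)
transform = create_transform(**cfg)

# Dummy forward pass at the native resolution to check shapes.
x = torch.randn(1, *cfg['input_size'])    # e.g. (1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)                       # torch.Size([1, 1000])

# For the rmlp variants, the MLP relative position embedding can be resized, so
# creating the model at a larger img_size (e.g. 320, matching the @ 320 results)
# should also work -- treat that as an assumption to verify for your timm version.
```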