allenai/allennlp v1.0.0

The 1.0 version of AllenNLP is the culmination of more than 500 commits over the course of several months of work from our engineering team. The AllenNLP library has had wide-reaching appeal so far in its lifetime, and this 1.0 release represents an important maturity milestone. While we will continue to move fast to keep up with the ever-changing state of the art, we will be increasingly conscious of the effect future API changes have on our existing user base.

This release touches almost every aspect of the library, ranging from improving documentation to adding new natural-language processing components, to adjusting our APIs so they serve the community for the long haul. While we cannot summarize everything in these release notes, here are some of the main milestones for the 1.0 release.

  1. We are releasing several new models, such as:
    a. TransformerQA, a reading comprehension model (paper, demo)
    b. An improved coreference model, with a 17% absolute improvement (architecture paper/embedder paper, demo)
    c. The NMN reading comprehension model (paper, demo)
    d. The RoBERTa models for textual entailment, or NLI (paper, demo)

  2. We have new introductory material in the form of an interactive guide, showing how to use library components and our experiment framework. The guide's goal is to provide a comprehensive introduction to AllenNLP for people who have a good understanding of machine learning and Python, along with some familiarity with PyTorch.

  3. We have improved performance across the library.
    a. We switched to native PyTorch data loading, which is not only much faster but also lets the three main parts of the library (data, model, and training) interoperate with any native PyTorch code.
    b. We enabled support for 16-bit floating point training through Apex.
    c. Multi-GPU training now uses a separate Python process for each GPU. These workers communicate using PyTorch's distributed module. This is more efficient than the old approach, which used a single Python process and was therefore limited by the GIL.

  4. We separated our models into a model repository (allennlp-models), so we have a lean core library with fewer dependencies.

  5. We dramatically simplified how AllenNLP code corresponds to AllenNLP configuration files, which also makes the library easy to use from raw Python.

The changes are not limited to these. Other highlights include:

  1. Support for gradient accumulation.
  2. Improved configurability of the trainer, so you can inject your own callback that runs on each batch.
  3. Seamless support for using word-piece tokenization on pre-tokenized text.
  4. A sampler that creates batches with roughly equal numbers of tokens.
  5. Unified support for Hugging Face's transformers library.
  6. Support for token type IDs throughout the library.
  7. Nightly releases of the library published to PyPI (installable with pip).
  8. BLEU and ROUGE metrics.

Updates since v1.0.0rc6

Fixed

  • Lazy dataset readers now work correctly with multi-process data loading.
  • Fixed race conditions that could occur when using a dataset cache.
  • Fixed a bug where all datasets would be loaded for vocab creation even if not needed.

Added

  • A parameter to the DatasetReader class: manual_multi_process_sharding. This is similar
    to the manual_distributed_sharding parameter, but applies when using a multi-process
    DataLoader.

Commits

29f3b6c Prepare for release v1.0.0
a8b840d fix some formatting issues in README (#4365)
d3ed619 fix Makefile
c554910 quick doc fixes (#4364)
b764bef simplify dataset classes, fix multi-process lazy loading (#4344)
884a614 Bump mkdocs-material from 5.2.3 to 5.3.0 (#4359)
6a124d8 ensure 'from_files' vocab doesn't load instances (#4356)
87c23e4 Fix handling of "datasets_for_vocab_creation" param (#4350)
c3755d1 update CHANGELOG

Upgrade guide from v0.9.0

There are too many changes to list exhaustively, but here are the most common issues:

  • You can continue to use the allennlp command line, but if you want to invoke it through Python, use python -m allennlp <command> instead of python -m allennlp.run <command>.
  • "bert_adam" is now "adamw".
  • We no longer support the "gradient_accumulation_batch_size" parameter to the trainer. Use "num_gradient_accumulation_steps" instead (see the sketch after this list).
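
As a minimal sketch of the last two points, a trainer block that used "bert_adam" and "gradient_accumulation_batch_size" in v0.9.0 might now look like the following (the learning rate and step count are placeholder values, not recommendations):

{
  "trainer": {
    "optimizer": {
      "type": "adamw",
      "lr": 1e-5
    },
    "num_gradient_accumulation_steps": 4,
    "num_epochs": 20,
    ...
  }
}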

Using the transformers library

AllenNLP 1.0 replaces the mish-mash of transformer libraries and dependencies that we had in v0.9.0 with a single implementation that uses https://github.com/huggingface/transformers under the hood. For cases where you can work directly with the word pieces used by the transformers, use "pretrained_transformer" for tokenizers, indexers, and embedders. If you want to use tokens from pre-tokenized text, use "pretrained_transformer_mismatched". The latter turns the text into word pieces, embeds them with the transformer, and then combines word pieces to produce an embedding for the original tokens.
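
A rough sketch of the pre-tokenized case (the model name is only an example, and the surrounding keys depend on your particular dataset reader and model):

{
  "dataset_reader": {
    ...
    "token_indexers": {
      "tokens": {
        "type": "pretrained_transformer_mismatched",
        "model_name": "bert-base-uncased"
      }
    }
  },
  "model": {
    ...
    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
          "type": "pretrained_transformer_mismatched",
          "model_name": "bert-base-uncased"
        }
      }
    }
  }
}

If you can work with word pieces directly, swap both types for "pretrained_transformer" and tokenize with "tokenizer": {"type": "pretrained_transformer", "model_name": "bert-base-uncased"}.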

The parameters requires_grad and top_layer_only are no longer supported. If you are converting an old model that used "bert-pretrained", this is important! requires_grad used to be False by default, so the transformer itself was not trained. That saved memory and time at the cost of performance. The new code does not support this setting and will always train the transformer. You can prevent this by setting requires_grad to False in a parameter group when setting up the optimizer, as in the sketch below.
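
A sketch of that optimizer setup (the regex "transformer" and the learning rate are hypothetical placeholders; use an expression that actually matches your model's transformer parameter names):

{
  "trainer": {
    "optimizer": {
      "type": "adamw",
      "lr": 1e-3,
      "parameter_groups": [
        [["transformer"], {"requires_grad": false}]
      ]
    },
    ...
  }
}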

You no longer need to specify do_lowercase, as this is handled automatically now.

Config file changes

In 1.0, we simplified how FromParams works. As a result, some things in the config files need to change to work with 1.0:

  • The way Vocabulary options are specified in config files has changed. See #3550. If you want to load a vocabulary from files, you should specify "type": "from_files", and use the key "directory" instead of "directory_path" (see the sketch after this list).
  • When instantiating a BasicTextFieldEmbedder from_params, you used to be able to have embedder names be top-level keys in the config file (e.g., "embedder": {"elmo": ELMO_PARAMS, "tokens": TOKEN_PARAMS}). We changed this a long time ago to prefer wrapping them in a "token_embedders" key, and this is now required (e.g., "embedder": {"token_embedders": {"elmo": ELMO_PARAMS, "tokens": TOKEN_PARAMS}}).
  • The TokenCharactersEncoder now requires you to specify the vocab_namespace for the underlying embedder. It used to default to "token_characters", matching the TokenCharactersIndexer default, but making that work required some custom magic that wasn't worth the complexity. So instead of "token_characters": {"type": "character_encoding", "embedding": {"embedding_dim": 25}, "encoder": {...}}, you need to change this to: "token_characters": {"type": "character_encoding", "embedding": {"embedding_dim": 25, "vocab_namespace": "token_characters"}, "encoder": {...}}
  • Regularization now needs another key in a config file. Instead of specifying regularization as "regularizer": [[regex1, regularizer_params], [regex2, regularizer_params]], it now must be specified as "regularizer": {"regexes": [[regex1, regularizer_params], [regex2, regularizer_params]]}.
  • We changed initialization in a similar way to regularization. Instead of specifying initialization as "initializer": [[regex1, initializer_params], [regex2, initializer_params]], it now must be specified as "initializer": {"regexes": [[regex1, initializer_params], [regex2, initializer_params]]}. Also, you used to be able to have initializer_params be "prevent", to prevent initialization of matching parameters. This is now done with a separate key passed to the initializer: "initializer": {"regexes": [..], "prevent_regexes": [regex1, regex2]}.
  • num_serialized_models_to_keep and keep_serialized_model_every_num_seconds used to be able to be passed as top-level parameters to the trainer, but now they must always be passed to the checkpointer instead. For example, if you had "trainer": {"num_serialized_models_to_keep": 1}, it now needs to be "trainer": {"checkpointer": {"num_serialized_models_to_keep": 1}}. Also, the default for that setting is now 2, so AllenNLP will no longer fill up your hard drive!
  • Tokenizer specification changed because of #3361. Instead of something like "tokenizer": {"word_splitter": {"type": "spacy"}}, you now just do "tokenizer": {"type": "spacy"} (more technically: the WordTokenizer has now been removed, with the things we used to call WordSplitters now just moved up to be top-level Tokenizers themselves).
  • The namespace_to_cache argument to ElmoTokenEmbedder has been removed as a config file option. You can still pass vocab_to_cache to the constructor of this class, but this functionality is no longer available from a config file. If you used this and are really put out by this change, let us know, and we'll see what we can do.
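
For example, a vocabulary built ahead of time and loaded from files would now be configured like this (the directory path is a placeholder):

{
  "vocabulary": {
    "type": "from_files",
    "directory": "/path/to/saved/vocabulary"
  }
}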

Iterators ➔ DataLoaders

AllenNLP now uses PyTorch's API for data iteration rather than our own custom one. This means that the train_data, validation_data, iterator, and validation_iterator arguments to the Trainer have been removed and replaced with data_loader and validation_data_loader.

Previous config files which looked like:

{
  "iterator": {
    "type": "bucket",
    "sorting_keys": [["tokens", "num_tokens"]],
    "padding_noise": 0.1
    ...
  }
}

Now become:

{
  "data_loader": {
    "batch_sampler" {
      "type": "bucket",
      // sorting keys are no longer required! They can be inferred automatically.
      "padding_noise": 0.1
      ...
    }
  }
}

Multi-GPU

AllenNLP now uses DistributedDataParallel for parallel training rather than DataParallel. With DistributedDataParallel, each worker (GPU) runs in its own process. As such, each process also has its own Trainer, which now takes a single GPU ID only.

Previous config files which looked like:

{
  "trainer": {
    "cuda_device": [0, 1, 2, 3],
    "num_epochs": 20,
    ...
  }
}

Now become:

{
  "distributed": {
    "cuda_devices": [0, 1, 2, 3],
  },
  "trainer": {
    "num_epochs": 20,
    ...
  }
}
