DyNet documentation¶
News! As of 6/28/2017, the master branch is DyNet version 2.0, which contains a number of changes, including a new model format. If you’re looking for the old version, check out the v1.1 branch.
DyNet (formerly known as cnn) is a neural network library developed by Carnegie Mellon University and many others. It is written in C++ (with bindings in Python) and is designed to be efficient when run on either CPU or GPU, and to work well with networks that have dynamic structures that change for every training instance. For example, these kinds of networks are particularly important in natural language processing tasks, and DyNet has been used to build state-of-the-art systems for syntactic parsing, machine translation, morphological inflection, and many other application areas.
Read the documentation below to get started. If you have any problems, see Debugging/Reporting Issues for how to debug and/or get in contact with the developers. We also greatly welcome contributions, so see Contributing to DyNet for details.
You can also read more technical details in our technical report. If you use DyNet for research, please cite this report as follows:
@article{dynet,
title={DyNet: The Dynamic Neural Network Toolkit},
author={Graham Neubig and Chris Dyer and Yoav Goldberg and Austin Matthews and Waleed Ammar and Antonios Anastasopoulos and Miguel Ballesteros and David Chiang and Daniel Clothiaux and Trevor Cohn and Kevin Duh and Manaal Faruqui and Cynthia Gan and Dan Garrette and Yangfeng Ji and Lingpeng Kong and Adhiguna Kuncoro and Gaurav Kumar and Chaitanya Malaviya and Paul Michel and Yusuke Oda and Matthew Richardson and Naomi Saphra and Swabha Swayamdipta and Pengcheng Yin},
journal={arXiv preprint arXiv:1701.03980},
year={2017}
}
DyNet can be installed according to the instructions below:
Installing DyNet for C++¶
How to build DyNet and link it with your C++ programs.
Prerequisites¶
DyNet relies on a number of external programs/libraries, including CMake, Eigen, and Mercurial (to install Eigen). CMake and Mercurial can be installed from standard repositories.
For example on Ubuntu Linux:
sudo apt-get install build-essential cmake mercurial
Or on macOS, first make sure the Apple Command Line Tools are installed, then get CMake and Mercurial with either homebrew or macports:
xcode-select --install
brew install cmake hg # Using homebrew.
sudo port install cmake mercurial # Using macports.
On Windows, see Windows Support.
To compile DyNet you also need the development version of the Eigen library. If you use any of the released versions, you may get assertion failures or compile errors. If you don’t have Eigen already, you can get it easily using the following command:
hg clone https://bitbucket.org/eigen/eigen/ -r 699b659
The -r NUM option specifies a revision number that is known to work. Adventurous users can remove it and use the very latest version, at the risk of the code breaking / not compiling. On macOS, you can install the latest development version of Eigen using Homebrew:
brew install --HEAD eigen
Building¶
To get and build DyNet, clone the repository
git clone https://github.com/clab/dynet.git
then enter the directory and use cmake to generate the makefiles. When you run cmake, you will need to specify the path to Eigen, and will probably want to specify ENABLE_CPP_EXAMPLES to compile the C++ examples.
cd dynet
mkdir build
cd build
cmake .. -DEIGEN3_INCLUDE_DIR=/path/to/eigen -DENABLE_CPP_EXAMPLES=ON
Then compile, where “2” can be replaced by the number of cores on your machine
make -j 2
To see that things have built properly, you can run
./examples/train_xor
which will train a multilayer perceptron to predict the xor function.
If any process here fails, please see Asking Questions/Reporting Bugs for help.
Compiling/linking external programs¶
When you want to use DyNet in an external program, you will need to add the dynet directory to the compile path:
-I/path/to/dynet
and link with the DyNet library:
-L/path/to/dynet/build/dynet -ldynet
GPU/cuDNN/MKL support¶
GPU (CUDA) support¶
DyNet supports running programs on GPUs with CUDA. If you have CUDA
installed, you can build DyNet with GPU support by adding
-DBACKEND=cuda
to your cmake options. The linking method is exactly the same as in the CPU case:
-L/path/to/dynet/build/dynet -ldynet
If you know the CUDA architecture supported by your GPU (e.g. by referencing this page), you can speed up compilation significantly by adding -DCUDA_ARCH=XXX, where XXX is your architecture number.
cuDNN support¶
When running DyNet with CUDA on GPUs, some of DyNet’s functionalities (e.g. conv2d) depend on the NVIDIA cuDNN libraries. CMake will automatically detect cuDNN in the installation path suggested by NVIDIA (i.e. /usr/local/cuda) and enable those functionalities if it is found.
If CMake is unable to find cuDNN automatically, try setting CUDNN_ROOT, such as
-DCUDNN_ROOT="/path/to/CUDNN"
However, if you don’t have cuDNN installed, the dependent functionalities will be automatically disabled and an error will be thrown at runtime if you try to use them.
Currently, DyNet supports cuDNN v5.1; support for newer versions is planned.
MKL support¶
DyNet can leverage Intel’s MKL library to speed up computation on the CPU. As an example, we’ve seen 3x speedup in seq2seq training when using MKL. To use MKL, include the following cmake option:
-DMKL=TRUE
If CMake is unable to find MKL automatically, try setting MKL_ROOT, such as
-DMKL_ROOT="/path/to/MKL"
One common install location is /opt/intel/mkl/.
If either MKL or MKL_ROOT is set, CMake will look for MKL.
By default, MKL will use all CPU cores. You can control how many cores MKL uses by setting the environment variable MKL_NUM_THREADS to the desired number. The following is the total time to process 250 training examples running the example encdec (on a 6 core Intel Xeon E5-1650):
encdec.exe --dynet-seed 1 --dynet-mem 1000 train-hsm.txt dev-hsm.txt
+-----------------+------------+---------+
| MKL_NUM_THREADS | Cores Used | Time(s) |
+-----------------+------------+---------+
| <Without MKL> | 1 | 28.6 |
| 1 | 1 | 13.3 |
| 2 | 2 | 9.5 |
| 3 | 3 | 8.1 |
| 4 | 4 | 7.8 |
| 6 | 6 | 8.2 |
+-----------------+------------+---------+
As you can see, for this particular example, using MKL roughly doubles the speed of computation while still using only one core. Increasing the number of cores to 2 or 3 is quite beneficial, but beyond that there are diminishing returns or even slowdown.
Compiling with Boost¶
DyNet requires Boost for a few pieces of less commonly used functionality to be enabled (unit tests and multi-processing). Boost can be enabled by passing the -DENABLE_BOOST=ON flag to cmake. In general, DyNet will find Boost if it is in the standard location. If Boost is in a non-standard location, say $HOME/boost, you can specify the location by adding the following to your CMake options:
-DBOOST_ROOT:PATHNAME=$HOME/boost -DBoost_LIBRARY_DIRS:FILEPATH=$HOME/boost/lib
-DBoost_NO_BOOST_CMAKE=TRUE -DBoost_NO_SYSTEM_PATHS=TRUE
Note that you will also have to set your LD_LIBRARY_PATH (DYLD_LIBRARY_PATH on macOS) to point to the boost/lib directory.
Note also that Boost must be compiled with the same compiler version as you are using to compile DyNet.
Windows Support¶
DyNet has been tested to build in Windows using Microsoft Visual Studio 2015. You may be able to build with MSVC 2013 by slightly modifying the instructions below.
First, install Eigen following the above instructions.
To generate the MSVC solution and project files, run cmake, pointing it to the location you installed Eigen (for example, at c:\libs\Eigen):
mkdir build
cd build
cmake .. -DEIGEN3_INCLUDE_DIR=c:/libs/Eigen -G"Visual Studio 14 2015 Win64"
This will generate dynet.sln. Simply open this and build all. Note: multi-process functionality is currently not supported on Windows, so the multi-process examples (*-mp) will not be included in the generated solution.
The Windows build also supports MKL and CUDA with the latest version of Eigen. If you build with CUDA and/or cuDNN, ensure their respective DLLs are in your PATH environment variable when you use dynet (whether in native C++ or Python). For example:
set PATH="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin";"c:\libs\cudnn-8.0-windows10-x64-v5.1\bin";%PATH%
Installing DyNet for Python¶
Python bindings to DyNet are supported for both Python 2.x and 3.x. Before installing DyNet, you will need to make sure that several packages are installed. For example on Ubuntu Linux:
sudo apt-get update
sudo apt-get install python-pip build-essential cmake mercurial
Or on macOS, first make sure the Apple Command Line Tools are installed, then get CMake and Mercurial with either homebrew or macports:
xcode-select --install
brew install cmake hg python # Using homebrew.
sudo port install cmake mercurial py-pip # Using macports.
On Windows, see Windows Support.
(Currently, since the pip installation builds from source, you need to install Cython first: pip install cython.)
Once these packages are installed, the following will download, build and install DyNet. Note that compiling DyNet may take a long time, up to 10 minutes or more, but as long as you see “Running setup.py install for dynet” with the moving progress wheel, things should be running.
pip install git+https://github.com/clab/dynet#egg=dynet
If you have CUDA installed on your system and want to install with GPU support, you can instead run the following command.
BACKEND=cuda pip install git+https://github.com/clab/dynet#egg=dynet
Alternatively, you can add the following to your requirements.txt (for CUDA support, you will need to make sure that BACKEND=cuda is in your environment variables when DyNet is installed):
git+https://github.com/clab/dynet#egg=dynet
You can also manually set the directory of the cuDNN library as follows:
CUDNN_ROOT=/path/to/cudnn BACKEND=cuda pip install git+https://github.com/clab/dynet#egg=dynet
If installation using pip fails, copy and paste the entire log that you get after running the pip command into a GitHub issue and we will help you debug. You can also try to install DyNet manually as described below.
Manual Installation¶
The following is a list of all the commands needed to perform a manual install:
# Installing Python DyNet:
pip install cython # if you don't have it already.
mkdir dynet-base
cd dynet-base
# getting dynet and eigen
git clone https://github.com/clab/dynet.git
hg clone https://bitbucket.org/eigen/eigen -r 699b659 # -r NUM specifies a known working revision
cd dynet
mkdir build
cd build
# without GPU support (if you get an error that Eigen cannot be found, try using the full path to Eigen)
cmake .. -DEIGEN3_INCLUDE_DIR=../../eigen -DPYTHON=`which python`
# or with GPU support (if you get an error that Eigen cannot be found, try using the full path to Eigen)
cmake .. -DEIGEN3_INCLUDE_DIR=../../eigen -DPYTHON=`which python` -DBACKEND=cuda
make -j 2 # replace 2 with the number of available cores
cd python
python ../../setup.py build --build-dir=.. --skip-build install # add `--user` for a user-local install.
# this should suffice, but on some systems you may need to add the following line to your
# init files in order for the compiled .so files to be accessible to Python.
# /path/to/dynet/build/dynet is the location in which libdynet.dylib resides.
export DYLD_LIBRARY_PATH=/path/to/dynet/build/dynet/:$DYLD_LIBRARY_PATH
# if the environment is Linux, use LD_LIBRARY_PATH instead.
export LD_LIBRARY_PATH=/path/to/dynet/build/dynet/:$LD_LIBRARY_PATH
To explain these one-by-one, first we get DyNet:
cd $HOME
mkdir dynet-base
cd dynet-base
git clone https://github.com/clab/dynet.git
cd dynet
git submodule init # To be consistent with DyNet's installation instructions.
git submodule update # To be consistent with DyNet's installation instructions.
Then get Eigen:
cd $HOME
cd dynet-base
hg clone https://bitbucket.org/eigen/eigen/ -r 346ecdb
(-r NUM specifies a known working revision of Eigen. You can remove this in order to get the bleeding edge Eigen, with the risk of some compile breaks, and the possible benefit of added optimizations.)
We also need to make sure the cython module is installed. (You can replace pip with your favorite package manager, such as conda, or install within a virtual environment.)
pip install cython
To simplify the following steps, we can set bash variables to hold the directories where we have saved DyNet and Eigen. If you obtained DyNet and Eigen differently from the instructions above and saved them in other locations, adjust these variables accordingly:
PATH_TO_DYNET=$HOME/dynet-base/dynet/
PATH_TO_EIGEN=$HOME/dynet-base/eigen/
Compile DyNet.
This is pretty much the same process as the C++ build described above, with the addition of the -DPYTHON= flag, pointing to the location of your Python interpreter.
Assuming that the cmake command found all the needed libraries and didn’t fail, the make command will take a while, and compile DyNet as well as the Python bindings. You can change make -j 2 to a higher number, depending on the available cores you want to use while compiling.
You now have a working Python binding inside of build/dynet. To verify this is working:
cd $PATH_TO_DYNET/build/python
python
then, within Python:
import dynet as dy
print dy.__version__
pc = dy.ParameterCollection()
In order to install the module so that it is accessible from everywhere in the system, run the following:
cd $PATH_TO_DYNET/build/python
python ../../setup.py EIGEN3_INCLUDE_DIR=$PATH_TO_EIGEN build --build-dir=.. --skip-build install --user
The --user switch will install the module in your local site-packages, and works without root privileges. To install the module to the system site-packages (for all users), or to the current virtualenv (if you are on one), run python ../../setup.py EIGEN3_INCLUDE_DIR=$PATH_TO_EIGEN build --build-dir=.. --skip-build install without this switch.
You should now have a working python binding (the dynet module). Note however that the installation relies on the compiled DyNet library being in $PATH_TO_DYNET/build/dynet, so make sure not to move it from there.
Now, check that everything works:
cd $PATH_TO_DYNET
cd examples/python
python xor.py
python rnnlm.py rnnlm.py
Alternatively, if the following script works for you, then your installation is likely to be working:
import dynet as dy
pc = dy.ParameterCollection()
If it doesn’t work and you get an error similar to the following:
ImportError: dlopen(/Users/sneharajana/.python-eggs/dyNET-0.0.0-py2.7-macosx-10.11-intel.egg-tmp/_dynet.so, 2): Library not loaded: @rpath/libdynet.dylib
Referenced from: /Users/sneharajana/.python-eggs/dyNET-0.0.0-py2.7-macosx-10.11-intel.egg-tmp/_dynet.so
Reason: image not found
then you may need to run the following (and add it to your shell init files):
# OSX
export DYLD_LIBRARY_PATH=/path/to/dynet/build/dynet/:$DYLD_LIBRARY_PATH
# Linux
export LD_LIBRARY_PATH=/path/to/dynet/build/dynet/:$LD_LIBRARY_PATH
# /path/to/dynet/build/dynet is the location in which libdynet.so(libdynet.dylib under osx) resides.
Anaconda Support¶
Anaconda is a popular package management system for Python, and DyNet can be installed into this environment. First, make sure that you install all the necessary packages according to the instructions at the top of this page. Then create an Anaconda environment and activate it as below:
conda create -n my_environment_name python
source activate my_environment_name
After this, you should be able to install using pip or manual installation as normal.
Windows Support¶
You can also use Python on Windows, including GPU and MKL support. For simplicity, we recommend using a Python distribution that already has Cython installed. The following has been tested to work:
- Install WinPython 2.7.10 (comes with Cython already installed).
- Compile DyNet according to the directions in the Windows C++ documentation (Windows Support), and additionally add the following flag when executing cmake: -DPYTHON=/path/to/your/python.exe
- Open a command prompt and set VS90COMNTOOLS to the path to your Visual Studio “Common7/Tools” directory. One easy way to do this is a command such as:
set VS90COMNTOOLS=%VS140COMNTOOLS%
- Open dynet.sln from this command prompt and build the “Release” version of the solution.
- Follow the rest of the instructions above for testing the build and installing it for other users.
Note that currently only the Release version works. Also, if you compile with CUDA and/or cuDNN, ensure their respective DLLs are in your PATH environment variable when you run Python.
GPU/MKL Support¶
Installing on GPU¶
To install on a computer with a GPU, first install CUDA. The following instructions assume CUDA is installed.
The installation process is pretty much the same, with the addition of the -DBACKEND=cuda flag at the cmake stage:
cmake .. -DEIGEN3_INCLUDE_DIR=$PATH_TO_EIGEN -DPYTHON=$PATH_TO_PYTHON -DBACKEND=cuda
If you know the CUDA architecture supported by your GPU (e.g. by referencing this page), you can speed up compilation significantly by adding -DCUDA_ARCH=XXX, where XXX is your architecture number.
If CUDA is installed in a non-standard location and cmake cannot find it, you can also specify -DCUDA_TOOLKIT_ROOT_DIR=/path/to/cuda.
Now, build the Python modules (as above, we assume Cython is installed).
After running make -j 2, you should have the file _dynet.so in the build/python folder.
As before, cd build/python followed by
python ../../setup.py EIGEN3_INCLUDE_DIR=$PATH_TO_EIGEN build --build-dir=.. --skip-build install --user
will install the module.
cuDNN support¶
When running DyNet with CUDA on GPUs, some of DyNet’s functionalities (e.g. conv2d) depend on the NVIDIA cuDNN libraries. CMake will automatically detect cuDNN in the installation path suggested by NVIDIA (i.e. /usr/local/cuda) and enable those functionalities if it is found.
If CMake is unable to find cuDNN automatically, try setting CUDNN_ROOT, such as
-DCUDNN_ROOT="/path/to/CUDNN"
However, if you don’t have cuDNN installed, the dependent functionalities will be automatically disabled and an error will be thrown at runtime if you try to use them.
Currently, DyNet supports cuDNN v5.1; support for newer versions is planned.
Using the GPU from Python¶
The preferred way to make dynet use the GPU under Python is to import dynet as usual:
import dynet
Then tell it to use the GPU by using the command-line switch --dynet-gpu, or the GPU switches detailed here, when invoking the program. This option lets the same code work with either the GPU or the CPU version depending on how it is invoked.
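For example, a minimal sketch of such a script (the file name train.py is only illustrative) contains no GPU-specific code at all; invoking it as python train.py --dynet-gpu is enough to move it to the GPU:
# train.py -- minimal sketch; the file name is only illustrative
import dynet as dy   # DyNet reads the --dynet-* switches from the command line when it is imported
pc = dy.ParameterCollection()
# ... build and train the model as usual; the same code runs on CPU or GPU ...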
Alternatively, you can also select whether the CPU or GPU should be used via the dynet_config module:
import dynet_config
dynet_config.set_gpu()
import dynet
This may be useful if you want to decide programmatically whether to use the CPU or GPU. Importantly, using import _dynet directly is no longer recommended.
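For instance, here is a minimal sketch that chooses the device from an application-level flag of its own (--use-gpu is a hypothetical flag handled by the script, not a DyNet option); the key point is that dynet_config must be used before dynet is imported:
import sys
import dynet_config
if "--use-gpu" in sys.argv:    # hypothetical flag belonging to this script, not to DyNet
    dynet_config.set_gpu()     # must be called before importing dynet
import dynet as dy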
Running with MKL¶
If you’ve built DyNet to use MKL (using -DMKL or -DMKL_ROOT), Python sometimes has difficulty finding the MKL shared libraries. You can try setting LD_LIBRARY_PATH to point to your MKL library directory. If that doesn’t work, try setting the following environment variable (supposing, for example, your MKL libraries are located at /opt/intel/mkl/lib/intel64):
export LD_PRELOAD=/opt/intel/mkl/lib/intel64/libmkl_def.so:/opt/intel/mkl/lib/intel64/libmkl_avx2.so:/opt/intel/mkl/lib/intel64/libmkl_core.so:/opt/intel/mkl/lib/intel64/libmkl_intel_lp64.so:/opt/intel/mkl/lib/intel64/libmkl_intel_thread.so:/opt/intel/lib/intel64_lin/libiomp5.so
Installing/Using in Other Languages¶
DyNet mainly supports the C++ and Python bindings, but there are also bindings for other languages that have been contributed by the community.
Scala/Java¶
DyNet has Scala/Java Bindings developed by Joel Grus at the Allen Institute for Artificial Intelligence. Please see the README linked above for details.
And get the basic information to create programs and use models:
DyNet Tutorial¶
C++ Tutorial¶
See the tutorials for the C++ version of DyNet
Basic Tutorial¶
An illustration of how parameter collections are trained (for a simple logistic regression model) is below:
First, we set up the structure of the parameter collection.
Create a parameter collection, and an SGD trainer to update its parameters.
ParameterCollection pc;
SimpleSGDTrainer trainer(pc);
Create a “computation graph,” which will define the flow of information.
ComputationGraph cg;
Initialize a 1x3 parameter vector, and add the parameters to be part of the computation graph.
Expression W = parameter(cg, pc.add_parameters({1, 3}));
Create variables defining the input and output of the regression, and load them into the computation graph. Note that we don’t need to set concrete values yet.
vector<dynet::real> x_values(3);
Expression x = input(cg, {3}, &x_values);
dynet::real y_value;
Expression y = input(cg, &y_value);
Next, set up the structure to multiply the input by the weight vector, then run the output of this through a logistic sigmoid function (logistic regression).
Expression y_pred = logistic(W*x);
Finally, we create a function to calculate the loss. The model will be optimized to minimize the value of the final function in the computation graph.
Expression l = binary_log_loss(y_pred, y);
We are now done setting up the graph, and we can print out its structure:
cg.print_graphviz();
Now, we perform a parameter update for a single example. Set the input/output to the values specified by the training data:
x_values = {0.5, 0.3, 0.7};
y_value = 1.0;
“forward” propagates values forward through the computation graph, and returns the loss.
dynet::real loss = as_scalar(cg.forward(l));
“backward” performs back-propagation, and accumulates the gradients of the parameters within the ParameterCollection data structure.
cg.backward(l);
trainer.update updates the parameters of the parameter collection that was passed to the trainer’s constructor.
trainer.update();
Note that this very simple example doesn’t cover things like memory initialization, reading/writing parameter collections, recurrent/LSTM networks, or adding biases to functions. The best way to get an idea of how to use DyNet for real is to look in the examples directory, particularly starting with the simplest xor example.
Saving and Loading¶
DyNet provides C++ interfaces for users to save and restore model parameters. The user has two options for saving a model. In the most basic use case, a complete ParameterCollection object can be saved. At loading time, the user should define and allocate the same parameter variables that were present in the model when it was saved (this usually amounts to having the same parameter-creation code called by both code paths), and then call populate, passing in the ParameterCollection object containing the parameters that should be loaded.
#include <dynet/io.h>
// save end
ParameterCollection m;
Parameter a = m.add_parameters({100});
LookupParameter b = m.add_lookup_parameters(10, {100});
Parameter c = m.add_parameters({1000});
{
dynet::TextFileSaver s("/tmp/tmp.model");
s.save(m);
}
// load end
ParameterCollection m;
m.add_parameters({100});
m.add_lookup_parameters(10, {100});
m.add_parameters({1000});
{
dynet::TextFileLoader l("/tmp/tmp.model");
l.populate(m);
}
However, in some cases it is useful to save only a subset of parameter objects (for example, if users wish to load them later in a pretraining setup). Here, Parameter or LookupParameter objects can be saved explicitly. Users can also specify keys for partial saving and loading.
#include <dynet/io.h>
// save end
ParameterCollection m1, m2;
m1.add_parameters({10}, "a");
m1.add_lookup_parameters(10, {2}, "la");
Parameter param_b = m2.add_parameters({3, 7});
{
dynet::TextFileSaver s("/tmp/tmp.model");
s.save(m1, "/namespace_tmp/");
s.save(param_b, "param_b");
}
// load end
ParameterCollection m;
m.add_parameters({10});
m.add_lookup_parameters(10, {2});
{
dynet::TextFileLoader l("/tmp/tmp.model");
l.populate(m, "/namespace_tmp/");
Parameter param_b = m.add_parameters({3, 7});
l.populate(param_b, "param_b");
}
// load end
// user can use equivalent interfaces to load model parameters
ParameterCollection m;
Parameter param_a, param_b;
LookupParameter l_param;
{
dynet::TextFileLoader l("/tmp/tmp.model");
param_a = l.load_param(m, "/namespace_tmp/a");
l_param = l.load_lookup_param(m, "/namespace_tmp/la");
param_b = l.load_param(m, "param_b");
}
A word of warning: in previous versions of DyNet, Builder objects needed to be serialized. This is no longer the case. (The Python interface does allow serialization of builder objects out of the box.)
Currently, DyNet only supports a plain text format. The native format is quite simple and therefore very readable. The model file consists of basic storage blocks. A basic block starts with a first line of metadata, #object_type# object_name dimension block_size, followed by the actual data. During the loading process, DyNet uses the metadata lines to locate the objects the user wants to load.
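For reference, the Python bindings expose an analogous save/load workflow. Below is a minimal sketch, assuming ParameterCollection provides save and populate as in recent versions of the Python API:
import dynet as dy
# save: build a collection and write it to disk
m = dy.ParameterCollection()
a = m.add_parameters((100,))
b = m.add_lookup_parameters((10, 100))
m.save("/tmp/tmp.model")
# load: recreate the same parameters, then populate them from disk
m2 = dy.ParameterCollection()
m2.add_parameters((100,))
m2.add_lookup_parameters((10, 100))
m2.populate("/tmp/tmp.model")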
Python Tutorial¶
Guided examples in Python can be found below:
Working with the python DyNet package¶
The DyNet package is intended for training and using neural networks, and is particularly suited for applications with dynamically changing network structures. It is a Python wrapper for the DyNet C++ package.
In neural network packages there are generally two modes of operation:
- Static networks, in which a network is built once and then fed with different inputs/outputs. Most NN packages work this way.
- Dynamic networks, in which a new network is built for each training example (sharing parameters with the networks of other training examples). This approach is what makes DyNet unique, and where most of its power comes from.
We will describe both of these modes.
Package Fundamentals¶
The main piece of DyNet is the ComputationGraph, which is what essentially defines a neural network. The ComputationGraph is composed of expressions, which relate to the inputs and outputs of the network, as well as the Parameters of the network. The parameters are the things in the network that are optimized over time, and all of the parameters sit inside a ParameterCollection. There are trainers (for example SimpleSGDTrainer) that are in charge of setting the parameter values.
We will not be using the ComputationGraph directly, but it is there in the background, as a singleton object. When dynet is imported, a new ComputationGraph is created. We can then reset the computation graph to a new state by calling renew_cg().
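As a minimal sketch, these pieces fit together as follows (the full xor example in the next section fills in the actual network):
import dynet as dy
pc = dy.ParameterCollection()       # holds all of the trainable parameters
trainer = dy.SimpleSGDTrainer(pc)   # will be in charge of updating the parameters in pc
dy.renew_cg()                       # reset the singleton ComputationGraph to a fresh state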
Static Networks¶
The life-cycle of a DyNet program is:
1. Create a ParameterCollection, and populate it with Parameters.
2. Renew the computation graph, and create Expressions representing the network (the network will include the Expressions for the Parameters defined in the parameter collection).
3. Optimize the model for the objective of the network.
As an example, consider a model for solving the “xor” problem. The network has two inputs, which can be 0 or 1, and a single output which should be the xor of the two inputs. We will model this as a multi-layer perceptron with a single hidden layer.
Let \(x = x_1, x_2\) be our input. We will have a hidden layer of 8 nodes, and an output layer of a single node. The activation on the hidden layer will be a \(\tanh\). Our network will then be:
\(\sigma(V(\tanh(Wx+b)))\)
Where \(W\) is an \(8 \times 2\) matrix, \(V\) is a \(1 \times 8\) matrix, and \(b\) is an 8-dim vector.
We want the output to be either 0 or 1, so we take the output layer to be the logistic-sigmoid function, \(\sigma(x)\), that takes values between \(-\infty\) and \(+\infty\) and returns numbers in \([0,1]\).
We will begin by defining the model and the computation graph.
In [1]:
# we assume that we have the dynet module in your path.
# OUTDATED: we also assume that LD_LIBRARY_PATH includes a pointer to where libcnn_shared.so is.
import dynet as dy
In [2]:
# create a parameter collection and add the parameters.
m = dy.ParameterCollection()
pW = m.add_parameters((8,2))
pV = m.add_parameters((1,8))
pb = m.add_parameters((8))
dy.renew_cg() # new computation graph. not strictly needed here, but good practice.
# associate the parameters with cg Expressions
W = dy.parameter(pW)
V = dy.parameter(pV)
b = dy.parameter(pb)
In [3]:
#b[1:-1].value()
b.value()
Out[3]:
[-0.5920619964599609,
-0.4818088114261627,
-0.011437613517045975,
-0.7547096610069275,
0.2887613773345947,
-0.39806437492370605,
-0.8494511246681213,
0.295582115650177]
The first block creates a parameter collection and populates it with parameters. The second block creates a computation graph and adds the parameters to it, transforming them into Expressions. The need to distinguish model parameters from “expressions” will become clearer later.
We now make use of the W and V expressions, in order to create the complete expression for the network.
In [4]:
x = dy.vecInput(2) # an input vector of size 2. Also an expression.
output = dy.logistic(V*(dy.tanh((W*x)+b)))
In [5]:
# we can now query our network
x.set([0,0])
output.value()
Out[5]:
0.706532895565033
In [6]:
# we want to be able to define a loss, so we need an input expression to work against.
y = dy.scalarInput(0) # this will hold the correct answer
loss = dy.binary_log_loss(output, y)
In [7]:
x.set([1,0])
y.set(0)
print loss.value()
y.set(1)
print loss.value()
1.25551486015
0.335373580456
Training¶
We now want to set the parameter weights such that the loss is minimized.
For this, we will use a trainer object. A trainer is constructed with respect to the parameters of a given model.
In [8]:
trainer = dy.SimpleSGDTrainer(m)
To use the trainer, we need to:
- call the forward_scalar method of ComputationGraph. This will run a forward pass through the network, calculating all the intermediate values up to the last one (loss, in our case), and then convert the value to a scalar. The final output of our network must be a single scalar value. However, if we do not care about the value, we can just use cg.forward() instead of cg.forward_scalar().
- call the backward method of ComputationGraph. This will run a backward pass from the last node, calculating the gradients with respect to minimizing the last expression (in our case we want to minimize the loss). The gradients are stored in the parameter collection, and we can now let the trainer take care of the optimization step.
- call trainer.update() to optimize the values with respect to the latest gradients.
In [9]:
x.set([1,0])
y.set(1)
loss_value = loss.value() # this performs a forward through the network.
print "the loss before step is:",loss_value
# now do an optimization step
loss.backward() # compute the gradients
trainer.update()
# see how it affected the loss:
loss_value = loss.value(recalculate=True) # recalculate=True means "don't use precomputed value"
print "the loss after step is:",loss_value
the loss before step is: 0.335373580456
the loss after step is: 0.296859383583
The optimization step indeed made the loss decrease. We now need to run this in a loop. To this end, we will create a training set, and iterate over it.
For the xor problem, the training instances are easy to create.
In [10]:
def create_xor_instances(num_rounds=2000):
questions = []
answers = []
for round in xrange(num_rounds):
for x1 in 0,1:
for x2 in 0,1:
answer = 0 if x1==x2 else 1
questions.append((x1,x2))
answers.append(answer)
return questions, answers
questions, answers = create_xor_instances()
We now feed each question / answer pair to the network, and try to minimize the loss.
In [11]:
total_loss = 0
seen_instances = 0
for question, answer in zip(questions, answers):
x.set(question)
y.set(answer)
seen_instances += 1
total_loss += loss.value()
loss.backward()
trainer.update()
if (seen_instances > 1 and seen_instances % 100 == 0):
print "average loss is:",total_loss / seen_instances
average loss is: 0.730996069312
average loss is: 0.686455376148
average loss is: 0.614968097508
average loss is: 0.529396591447
average loss is: 0.454356552631
average loss is: 0.39492503399
average loss is: 0.348310606687
average loss is: 0.311234809482
average loss is: 0.281200638587
average loss is: 0.256437818106
average loss is: 0.235696636033
average loss is: 0.218082525641
average loss is: 0.202943060785
average loss is: 0.189793206944
average loss is: 0.178265773896
average loss is: 0.168078109015
average loss is: 0.15900931143
average loss is: 0.150884356805
average loss is: 0.143562835396
average loss is: 0.136930837112
average loss is: 0.130894997159
average loss is: 0.125378077089
average loss is: 0.120315633187
average loss is: 0.115653475622
average loss is: 0.111345707807
average loss is: 0.107353201057
average loss is: 0.103642390902
average loss is: 0.100184321725
average loss is: 0.0969538828368
average loss is: 0.0939291894056
average loss is: 0.0910910811149
average loss is: 0.0884227104994
average loss is: 0.0859092032744
average loss is: 0.0835373785728
average loss is: 0.0812955136038
average loss is: 0.0791731475857
average loss is: 0.0771609158713
average loss is: 0.0752504101568
average loss is: 0.0734340592178
average loss is: 0.0717050271845
average loss is: 0.0700571256665
average loss is: 0.0684847396141
average loss is: 0.0669827620572
average loss is: 0.0655465372522
average loss is: 0.0641718128339
average loss is: 0.0628546962203
average loss is: 0.0615916178524
average loss is: 0.0603792975615
average loss is: 0.0592147165184
average loss is: 0.0580950913344
average loss is: 0.0570178513814
average loss is: 0.0559806190546
average loss is: 0.0549811920022
average loss is: 0.0540175269391
average loss is: 0.0530877257938
average loss is: 0.0521900229302
average loss is: 0.0513227736969
average loss is: 0.0504844442235
average loss is: 0.0496736022536
average loss is: 0.0488889090025
average loss is: 0.0481291114653
average loss is: 0.0473930355647
average loss is: 0.0466795804093
average loss is: 0.0459877123818
average loss is: 0.0453164599289
average loss is: 0.0446649091876
average loss is: 0.0440321997496
average loss is: 0.0434175205679
average loss is: 0.0428201068594
average loss is: 0.042239236579
average loss is: 0.041674227424
average loss is: 0.0411244342562
average loss is: 0.0405892467939
average loss is: 0.0400680867989
average loss is: 0.0395604063634
average loss is: 0.0390656857708
average loss is: 0.0385834318376
average loss is: 0.0381131761705
average loss is: 0.037654473684
average loss is: 0.0372069010154
Our network is now trained. Let’s verify that it indeed learned the xor function:
In [12]:
x.set([0,1])
print "0,1",output.value()
x.set([1,0])
print "1,0",output.value()
x.set([0,0])
print "0,0",output.value()
x.set([1,1])
print "1,1",output.value()
0,1 0.998090803623
1,0 0.998076915741
0,0 0.00135990511626
1,1 0.00213058013469
In case we are curious about the parameter values, we can query them:
In [13]:
W.value()
Out[13]:
array([[ 1.26847982, 1.25287616],
[ 0.91610891, 0.80253637],
[ 3.18741179, -2.58643913],
[-0.82472938, -0.68830448],
[-2.74162889, 3.30151606],
[ 0.2677069 , 0.46926948],
[-2.60197234, -2.61786079],
[ 0.89582258, -0.44721049]])
In [14]:
V.value()
Out[14]:
array([[-2.33788562, -1.54022419, -4.58266163, -0.91096258, -4.88002253,
-0.70912606, -4.09791088, -0.61150461]])
In [15]:
b.value()
Out[15]:
[-1.9798537492752075,
-1.3854612112045288,
1.2350027561187744,
-0.8094932436943054,
1.3227168321609497,
-0.5688062906265259,
0.9074684381484985,
0.21831640601158142]
To summarize¶
Here is a complete program:
In [16]:
import dynet as dy

# define the parameters
m = dy.ParameterCollection()
pW = m.add_parameters((8,2))
pV = m.add_parameters((1,8))
pb = m.add_parameters((8))
# renew the computation graph
dy.renew_cg()
# add the parameters to the graph
W = dy.parameter(pW)
V = dy.parameter(pV)
b = dy.parameter(pb)
# create the network
x = dy.vecInput(2) # an input vector of size 2.
output = dy.logistic(V*(dy.tanh((W*x)+b)))
# define the loss with respect to an output y.
y = dy.scalarInput(0) # this will hold the correct answer
loss = dy.binary_log_loss(output, y)
# create training instances
def create_xor_instances(num_rounds=2000):
questions = []
answers = []
for round in xrange(num_rounds):
for x1 in 0,1:
for x2 in 0,1:
answer = 0 if x1==x2 else 1
questions.append((x1,x2))
answers.append(answer)
return questions, answers
questions, answers = create_xor_instances()
# train the network
trainer = dy.SimpleSGDTrainer(m)
total_loss = 0
seen_instances = 0
for question, answer in zip(questions, answers):
x.set(question)
y.set(answer)
seen_instances += 1
total_loss += loss.value()
loss.backward()
trainer.update()
if (seen_instances > 1 and seen_instances % 100 == 0):
print "average loss is:",total_loss / seen_instances
average loss is: 0.725458401442
average loss is: 0.656036808193
average loss is: 0.563800293456
average loss is: 0.473188629244
average loss is: 0.401578919515
average loss is: 0.347210133697
average loss is: 0.30537398648
average loss is: 0.27243115149
average loss is: 0.245902155418
average loss is: 0.22411154042
average loss is: 0.205906257995
average loss is: 0.190473453378
average loss is: 0.177226172269
average loss is: 0.165731058566
average loss is: 0.155661680364
average loss is: 0.146767699362
average loss is: 0.138854031509
average loss is: 0.131766459678
average loss is: 0.125381493949
average loss is: 0.119599098227
average loss is: 0.114337381247
average loss is: 0.109528665657
average loss is: 0.105116533384
average loss is: 0.101053577985
average loss is: 0.0972996741069
average loss is: 0.093820632044
average loss is: 0.0905871372991
average loss is: 0.0875739114509
average loss is: 0.0847590394488
average loss is: 0.0821234288742
average loss is: 0.079650368163
average loss is: 0.0773251660003
average loss is: 0.0751348558335
average loss is: 0.0730679483965
average loss is: 0.0711142273374
average loss is: 0.0692645774255
average loss is: 0.0675108397355
average loss is: 0.0658456894337
average loss is: 0.0642625315812
average loss is: 0.0627554119665
average loss is: 0.0613189413034
average loss is: 0.059948229676
average loss is: 0.0586388300699
average loss is: 0.05738668844
average loss is: 0.0561881021362
average loss is: 0.0550396820511
average loss is: 0.0539383201534
average loss is: 0.0528811609025
average loss is: 0.0518655761557
average loss is: 0.0508891425877
average loss is: 0.0499496224367
average loss is: 0.0490449456893
average loss is: 0.0481731953563
average loss is: 0.0473325925335
average loss is: 0.0465214848134
average loss is: 0.0457383351514
average loss is: 0.0449817118815
average loss is: 0.0442502796927
average loss is: 0.0435427918518
average loss is: 0.0428580828441
average loss is: 0.0421950617608
average loss is: 0.0415527067172
average loss is: 0.0409300591527
average loss is: 0.0403262192239
average loss is: 0.0397403411381
average loss is: 0.0391716292271
average loss is: 0.0386193343495
average loss is: 0.0380827505725
average loss is: 0.0375612118193
average loss is: 0.0370540894219
average loss is: 0.0365607894682
average loss is: 0.0360807502221
average loss is: 0.0356134402267
average loss is: 0.0351583559568
average loss is: 0.0347150203697
average loss is: 0.0342829808685
average loss is: 0.0338618080745
average loss is: 0.0334510939502
average loss is: 0.0330504509121
average loss is: 0.0326595103741
Dynamic Networks¶
Dynamic networks are very similar to static ones, but instead of creating the network once and then calling “set” in each training example to change the inputs, we just create a new network for each training example.
We present an example below. While the value of this may not be clear in the xor example, the dynamic approach is very convenient for networks for which the structure is not fixed, such as recurrent or recursive networks.
In [17]:
import dynet as dy
# create training instances, as before
def create_xor_instances(num_rounds=2000):
questions = []
answers = []
for round in xrange(num_rounds):
for x1 in 0,1:
for x2 in 0,1:
answer = 0 if x1==x2 else 1
questions.append((x1,x2))
answers.append(answer)
return questions, answers
questions, answers = create_xor_instances()
# create a network for the xor problem given input and output
def create_xor_network(pW, pV, pb, inputs, expected_answer):
dy.renew_cg() # new computation graph
W = dy.parameter(pW) # add parameters to graph as expressions
V = dy.parameter(pV)
b = dy.parameter(pb)
x = dy.vecInput(len(inputs))
x.set(inputs)
y = dy.scalarInput(expected_answer)
output = dy.logistic(V*(dy.tanh((W*x)+b)))
loss = dy.binary_log_loss(output, y)
return loss
m2 = dy.ParameterCollection()
pW = m2.add_parameters((8,2))
pV = m2.add_parameters((1,8))
pb = m2.add_parameters((8))
trainer = dy.SimpleSGDTrainer(m2)
seen_instances = 0
total_loss = 0
for question, answer in zip(questions, answers):
loss = create_xor_network(pW, pV, pb, question, answer)
seen_instances += 1
total_loss += loss.value()
loss.backward()
trainer.update()
if (seen_instances > 1 and seen_instances % 100 == 0):
print "average loss is:",total_loss / seen_instances
average loss is: 0.736730417013
average loss is: 0.725369692743
average loss is: 0.715208243926
average loss is: 0.698906037733
average loss is: 0.667973376453
average loss is: 0.620016210104
average loss is: 0.564173455558
average loss is: 0.511108190748
average loss is: 0.464656613212
average loss is: 0.424903827408
average loss is: 0.390944672838
average loss is: 0.361782596097
average loss is: 0.336552875967
average loss is: 0.314552738269
average loss is: 0.295221981726
average loss is: 0.27811523865
average loss is: 0.262876965393
average loss is: 0.249221329002
average loss is: 0.236916671552
average loss is: 0.225773662324
average loss is: 0.215636288271
average loss is: 0.206374970573
average loss is: 0.197881278039
average loss is: 0.190063834667
average loss is: 0.182845127269
average loss is: 0.176158992879
average loss is: 0.16994863152
average loss is: 0.164165015582
average loss is: 0.158765610311
average loss is: 0.153713339384
average loss is: 0.148975738776
average loss is: 0.14452426397
average loss is: 0.140333718062
average loss is: 0.13638177571
average loss is: 0.132648585576
average loss is: 0.129116437846
average loss is: 0.125769484215
average loss is: 0.122593499324
average loss is: 0.119575678358
average loss is: 0.116704463887
average loss is: 0.113969398874
average loss is: 0.111360997359
average loss is: 0.108870635643
average loss is: 0.106490455879
average loss is: 0.104213282756
average loss is: 0.102032551605
average loss is: 0.0999422444205
average loss is: 0.0979368338955
average loss is: 0.0960112348951
average loss is: 0.094160760665
average loss is: 0.0923810851444
average loss is: 0.0906682085468
average loss is: 0.0890184267577
average loss is: 0.0874283051604
average loss is: 0.0858946543594
average loss is: 0.0844145084265
average loss is: 0.0829851059784
average loss is: 0.0816038727351
average loss is: 0.0802684055211
average loss is: 0.0789764590814
average loss is: 0.0777259325812
average loss is: 0.0765148587798
average loss is: 0.0753413928689
average loss is: 0.0742038039022
average loss is: 0.073100465403
average loss is: 0.072029847966
average loss is: 0.0709905121502
average loss is: 0.0699811016467
average loss is: 0.0690003377412
average loss is: 0.0680470136383
average loss is: 0.0671199895066
average loss is: 0.0662181878878
average loss is: 0.0653405894968
average loss is: 0.0644862291951
average loss is: 0.0636541927901
average loss is: 0.0628436133573
average loss is: 0.062053668331
average loss is: 0.0612835769022
average loss is: 0.0605325971122
average loss is: 0.0598000235481
API tutorial¶
Expression building¶
(note: may have old API in some cases)
In [ ]:
import dynet as dy
## ==== Create a new computation graph
# (it is a singleton, we have one at each stage.
# dy.renew_cg() clears the current one and starts anew)
dy.renew_cg()
## ==== Creating Expressions from user input / constants.
x = dy.scalarInput(value)
v = dy.vecInput(dimension)
v.set([1,2,3])
z = dy.matInput(dim1, dim2)
# for example:
z1 = dy.matInput(2, 2)
z1.set([1,2,3,4]) # Column major
# Or directly from a numpy array
z1 = dy.inputTensor([[1,2],[3,4]]) # Row major
## ==== We can take the value of an expression.
# For complex expressions, this will run forward propagation.
print z.value()
print z.npvalue() # as numpy array
print v.vec_value() # as vector, if vector
print x.scalar_value() # as scalar, if scalar
print x.value() # choose the correct one
## ==== Parameters
# Parameters are things we tune during training.
# Usually a matrix or a vector.
# First we create a parameter collection and add the parameters to it.
m = dy.ParameterCollection()
pW = m.add_parameters((8,8)) # an 8x8 matrix
pb = m.add_parameters(8)
# then we create an Expression out of the parameter collection's parameters
W = dy.parameter(pW)
b = dy.parameter(pb)
## ===== Lookup parameters
# Similar to parameters, but are representing a "lookup table"
# that maps numbers to vectors.
# These are used for embedding matrices.
# for example, this will have VOCAB_SIZE rows, each of DIM dimensions.
lp = m.add_lookup_parameters((VOCAB_SIZE, DIM))
# lookup parameters can be initialized from an existing array, i.e:
# m["lookup"].init_from_array(wv)
e5 = dy.lookup(lp, 5) # create an Expression from row 5.
e5 = lp[5] # same
e5c = dy.lookup(lp, 5, update=False) # as before, but don't update when optimizing.
e5 = dy.lookup_batch(lp, [4, 5]) # create a batched Expression from rows 4 and 5.
e5 = lp.batch([4, 5]) # same
e5.set(9) # now the e5 expression contains row 9
e5c.set(9) # ditto
## ===== Combine expression into complex expressions.
# Math
e = e1 + e2
e = e1 * e2 # for vectors/matrices: matrix multiplication (like e1.dot(e2) in numpy)
e = e1 - e2
e = -e1
e = dy.dot_product(e1, e2)
e = dy.cmult(e1, e2) # component-wise multiply (like e1*e2 in numpy)
e = dy.cdiv(e1, e2) # component-wise divide
e = dy.colwise_add(e1, e2) # column-wise addition
# Matrix Shapes
e = dy.reshape(e1, new_dimension)
e = dy.transpose(e1)
# Per-element unary functions.
e = dy.tanh(e1)
e = dy.exp(e1)
e = dy.log(e1)
e = dy.logistic(e1) # Sigmoid(x)
e = dy.rectify(e1) # Relu (= max(x,0))
e = dy.softsign(e1) # x/(1+|x|)
# softmaxes
e = dy.softmax(e1)
e = dy.log_softmax(e1, restrict=[]) # restrict is a set of indices.
# if not empty, only entries in restrict are part
# of softmax computation, others get 0.
e = dy.sum_cols(e1)
# Picking values from vector expressions
e = dy.pick(e1, k) # k is unsigned integer, e1 is vector. return e1[k]
e = e1[k] # same
e = dy.pickrange(e1, k, v) # like python's e1[k:v] for lists. e1 is an Expression, k,v integers.
e = e1[k:v] # same
e = dy.pickneglogsoftmax(e1, k) # k is unsigned integer. equiv to: (pick(-log(dy.softmax(e1)), k))
# Neural net stuff
dy.noise(e1, stddev) # add noise to each element from a Gaussian with standard-dev = stddev
dy.dropout(e1, p) # apply dropout with probability p
# functions over lists of expressions
e = dy.esum([e1, e2, ...]) # sum
e = dy.average([e1, e2, ...]) # average
e = dy.concatenate_cols([e1, e2, ...]) # e1, e2,.. are column vectors. return a matrix. (similar to np.hstack([e1,e2,...]))
e = dy.concatenate([e1, e2, ...]) # concatenate
e = dy.affine_transform([e0,e1,e2, ...]) # e = e0 + ((e1*e2) + (e3*e4) ...)
## Loss functions
e = dy.squared_distance(e1, e2)
e = dy.l1_distance(e1, e2)
e = dy.huber_distance(e1, e2, c=1.345)
# e1 must be a scalar that is a value between 0 and 1
# e2 (ty) must be a scalar that is a value between 0 and 1
# e = ty * log(e1) + (1 - ty) * log(1 - e1)
e = dy.binary_log_loss(e1, e2)
# e1 is row vector or scalar
# e2 is row vector or scalar
# m is number
# e = max(0, m - (e1 - e2))
e = dy.pairwise_rank_loss(e1, e2, m=1.0)
# Convolutions
# e1 \in R^{d x s} (input)
# e2 \in R^{d x m} (filter)
e = dy.conv1d_narrow(e1, e2) # e = e1 *conv e2
e = dy.conv1d_wide(e1, e2) # e = e1 *conv e2
e = dy.filter1d_narrow(e1, e2) # e = e1 *filter e2
e = dy.kmax_pooling(e1, k) # kmax-pooling operation (Kalchbrenner et al 2014)
e = dy.kmh_ngram(e1, k) #
e = dy.fold_rows(e1, nrows=2) #
Recipe¶
In [6]:
import dynet as dy
import numpy as np
# create parameter collection
m = dy.ParameterCollection()
# add parameters to parameter collection
pW = m.add_parameters((10,30))
pB = m.add_parameters(10)
lookup = m.add_lookup_parameters((500, 10))
print "added"
# create trainer
trainer = dy.SimpleSGDTrainer(m)
# Regularization is set via the --dynet-l2 commandline flag.
# Learning rate parameters can be passed to the trainer:
# alpha = 0.1 # learning rate
# trainer = dy.SimpleSGDTrainer(m, e0=alpha)
# function for graph creation
def create_network_return_loss(inputs, expected_output):
"""
inputs is a list of numbers
"""
dy.renew_cg()
W = dy.parameter(pW) # from parameters to expressions
b = dy.parameter(pB)
emb_vectors = [lookup[i] for i in inputs]
net_input = dy.concatenate(emb_vectors)
net_output = dy.softmax( (W*net_input) + b)
loss = -dy.log(dy.pick(net_output, expected_output))
return loss
# function for prediction
def create_network_return_best(inputs):
"""
inputs is a list of numbers
"""
dy.renew_cg()
W = dy.parameter(pW)
b = dy.parameter(pB)
emb_vectors = [lookup[i] for i in inputs]
net_input = dy.concatenate(emb_vectors)
net_output = dy.softmax( (W*net_input) + b)
return np.argmax(net_output.npvalue())
# train network
for epoch in xrange(5):
for inp,lbl in ( ([1,2,3],1), ([3,2,4],2) ):
print inp, lbl
loss = create_network_return_loss(inp, lbl)
print loss.value() # need to run loss.value() for the forward prop
loss.backward()
trainer.update()
print create_network_return_best([1,2,3])
added
[1, 2, 3] 1
2.71492385864
[3, 2, 4] 2
2.48228144646
[1, 2, 3] 1
2.00279903412
[3, 2, 4] 2
1.82602763176
[1, 2, 3] 1
1.44809651375
[3, 2, 4] 2
1.34181213379
[1, 2, 3] 1
1.03570735455
[3, 2, 4] 2
0.988352060318
[1, 2, 3] 1
0.744616270065
[3, 2, 4] 2
0.732948303223
1
Recipe (using classes)¶
In [4]:
import dynet as dy
import numpy as np
# create parameter collection
m = dy.ParameterCollection()
# create a class encapsulating the network
class OurNetwork(object):
# The init method adds parameters to the parameter collection.
def __init__(self, pc):
self.pW = pc.add_parameters((10,30))
self.pB = pc.add_parameters(10)
self.lookup = pc.add_lookup_parameters((500,10))
# the __call__ method applies the network to an input
def __call__(self, inputs):
W = dy.parameter(self.pW)
b = dy.parameter(self.pB)
lookup = self.lookup
emb_vectors = [lookup[i] for i in inputs]
net_input = dy.concatenate(emb_vectors)
net_output = dy.softmax( (W*net_input) + b)
return net_output
def create_network_return_loss(self, inputs, expected_output):
dy.renew_cg()
out = self(inputs)
loss = -dy.log(dy.pick(out, expected_output))
return loss
def create_network_return_best(self, inputs):
dy.renew_cg()
out = self(inputs)
return np.argmax(out.npvalue())
# create network
network = OurNetwork(m)
# create trainer
trainer = dy.SimpleSGDTrainer(m)
# train network
for epoch in xrange(5):
for inp,lbl in ( ([1,2,3],1), ([3,2,4],2) ):
print inp, lbl
loss = network.create_network_return_loss(inp, lbl)
print loss.value() # need to run loss.value() for the forward prop
loss.backward()
trainer.update()
print
print network.create_network_return_best([1,2,3])
[1, 2, 3] 1
2.5900914669
[3, 2, 4] 2
2.00347089767
[1, 2, 3] 1
1.98409461975
[3, 2, 4] 2
1.50869822502
[1, 2, 3] 1
1.50195622444
[3, 2, 4] 2
1.12316584587
[1, 2, 3] 1
1.12293696404
[3, 2, 4] 2
0.831095397472
[1, 2, 3] 1
0.833912611008
[3, 2, 4] 2
0.61754822731
1
or, alternatively, have the training outside of the network class¶
In [ ]:
# create network
network = OurNetwork(m)
# create trainer
trainer = dy.SimpleSGDTrainer(m)
# train network
for epoch in xrange(5):
for inp,lbl in ( ([1,2,3],1), ([3,2,4],2) ):
print inp, lbl
dy.renew_cg()
out = network(inp)
loss = -dy.log(dy.pick(out, lbl))
print loss.value() # need to run loss.value() for the forward prop
loss.backward()
trainer.update()
print
print np.argmax(network([1,2,3]).npvalue())
[1, 2, 3] 1
3.63615298271
[3, 2, 4] 2
3.29473733902
[1, 2, 3] 1
2.81605744362
[3, 2, 4] 2
2.46070289612
[1, 2, 3] 1
2.13946056366
[3, 2, 4] 2
1.77259361744
[1, 2, 3] 1
1.57904195786
[3, 2, 4] 2
1.2269589901
[1, 2, 3] 1
1.13014268875
[3, 2, 4] 2
0.830479979515
1
RNNs tutorial¶
In [1]:
# we assume that we have the dynet module in your path.
# OUTDATED: we also assume that LD_LIBRARY_PATH includes a pointer to where libcnn_shared.so is.
import dynet as dy
An LSTM/RNN overview:¶
A (1-layer) RNN can be thought of as a sequence of cells, \(h_1,...,h_k\), where the index \(i\) of \(h_i\) indicates the time dimension.
Each cell \(h_i\) has an input \(x_i\) and an output \(r_i\). In addition to \(x_i\), cell \(h_i\) receives as input also \(r_{i-1}\).
In a deep (multi-layer) RNN, we don’t have a sequence, but a grid. That is, we have several layers of sequences:
- \(h_1^3,...,h_k^3\)
- \(h_1^2,...,h_k^2\)
- \(h_1^1,...,h_k^1\)
Let \(r_i^j\) be the output of cell \(h_i^j\). Then:
The input to \(h_i^1\) is \(x_i\) and \(r_{i-1}^1\).
The input to \(h_i^2\) is \(r_i^1\) and \(r_{i-1}^2\), and so on.
The LSTM (RNN) Interface¶
RNN / LSTM / GRU follow the same interface. We have a “builder” which is in charge of creating and defining the parameters for the sequence.
In [2]:
pc = dy.ParameterCollection()
NUM_LAYERS=2
INPUT_DIM=50
HIDDEN_DIM=10
builder = dy.LSTMBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
# or:
# builder = dy.SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
Note that when we create the builder, it adds the internal RNN
parameters to the ParameterCollection
. We do not need to care about
them, but they will be optimized together with the rest of the network’s
parameters.
In [3]:
s0 = builder.initial_state()
In [4]:
x1 = dy.vecInput(INPUT_DIM)
In [5]:
s1=s0.add_input(x1)
y1 = s1.output()
# here, we add x1 to the RNN, and the output we get from the top is y1 (a HIDDEN_DIM-dim vector)
In [6]:
y1.npvalue().shape
Out[6]:
(10,)
In [7]:
s2=s1.add_input(x1) # we can add another input
y2=s2.output()
If our LSTM/RNN was one layer deep, y2 would be equal to the hidden state. However, since it is 2 layers deep, y2 is only the hidden state (= output) of the last layer.
If we want access to all of the hidden states (the outputs of both the first and the last layers), we can use the .h() method, which returns a list of expressions, one for each layer:
In [8]:
print s2.h()
(exprssion 54/0, exprssion 66/0)
The same interface that we saw until now for the LSTM also holds for the simple RNN:
In [9]:
# create a simple rnn builder
rnnbuilder=dy.SimpleRNNBuilder(NUM_LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
# initialize a new graph, and a new sequence
rs0 = rnnbuilder.initial_state()
# add inputs
rs1 = rs0.add_input(x1)
ry1 = rs1.output()
print "all layers:", s1.h()
all layers: (exprssion 32/0, exprssion 42/0)
In [10]:
print s1.s()
(exprssion 28/0, exprssion 38/0, exprssion 32/0, exprssion 42/0)
To summarize, when calling .add_input(x) on an RNNState, the state creates a new RNN/LSTM column, passing it:
1. the state of the current RNN column
2. the input x
The new state is then returned, and we can call its output() method to get the output y, which is the output at the top of the column. We can access the outputs of all the layers (not only the last one) using the .h() method of the state.
.s(): The internal state of the RNN may be more involved than just the outputs \(h\). This is the case for the LSTM, which keeps an extra “memory” cell that is used when calculating \(h\), and which is also passed to the next column. To access the entire hidden state, we use the .s() method.
The output of .s() differs by the type of RNN being used. For the simple RNN, it is the same as .h(). For the LSTM, it is more involved.
In [11]:
rnn_h = rs1.h()
rnn_s = rs1.s()
print "RNN h:", rnn_h
print "RNN s:", rnn_s
lstm_h = s1.h()
lstm_s = s1.s()
print "LSTM h:", lstm_h
print "LSTM s:", lstm_s
RNN h: (exprssion 74/0, exprssion 76/0)
RNN s: (exprssion 74/0, exprssion 76/0)
LSTM h: (exprssion 32/0, exprssion 42/0)
LSTM s: (exprssion 28/0, exprssion 38/0, exprssion 32/0, exprssion 42/0)
As we can see, the LSTM has two extra state expressions (one for each hidden layer) before the outputs h.
Extra options in the RNN/LSTM interface¶
Stack LSTM: The RNNs are shaped as a stack: we can remove the top and continue from the previous state. This is done either by remembering the previous state and continuing it with a new .add_input(), or by accessing the previous state of a given state using its .prev() method.
Initializing a new sequence with a given state: When we call builder.initial_state(), we are assuming the state has a random or zero initialization. If we want, we can specify a list of expressions that will serve as the initial state. The expected format is the same as the result of a call to .final_s(). TODO: this is not supported yet.
In [12]:
s2=s1.add_input(x1)
s3=s2.add_input(x1)
s4=s3.add_input(x1)
# let's continue s3 with a new input.
s5=s3.add_input(x1)
# we now have two different sequences:
# s0,s1,s2,s3,s4
# s0,s1,s2,s3,s5
# the two sequences share parameters.
assert(s5.prev() == s3)
assert(s4.prev() == s3)
s6=s3.prev().add_input(x1)
# we now have an additional sequence:
# s0,s1,s2,s6
In [13]:
s6.h()
Out[13]:
(expression 184/0, expression 196/0)
In [14]:
s6.s()
Out[14]:
(expression 180/0, expression 192/0, expression 184/0, expression 196/0)
Aside: memory efficient transduction¶
The RNNState interface is convenient and allows for incremental input construction. However, sometimes we know the sequence of inputs in advance and care only about the sequence of output expressions. In this case, we can use the add_inputs(xs) method, where xs is a list of Expressions.
In [15]:
state = rnnbuilder.initial_state()
xs = [x1,x1,x1]
states = state.add_inputs(xs)
outputs = [s.output() for s in states]
hs = [s.h() for s in states]
print outputs, hs
[expression 200/0, expression 206/0, expression 212/0] [(expression 198/0, expression 200/0), (expression 203/0, expression 206/0), (expression 209/0, expression 212/0)]
This is convenient.
What if we do not care about .s() and .h(), and do not need to access the previous vectors? In such cases we can use the transduce(xs) method instead of add_inputs(xs). transduce takes in a sequence of Expressions and returns a sequence of Expressions. As a consequence of not returning RNNStates, transduce is much more memory efficient than add_inputs or a series of calls to add_input.
In [16]:
state = rnnbuilder.initial_state()
xs = [x1,x1,x1]
outputs = state.transduce(xs)
print outputs
[expression 216/0, expression 222/0, expression 228/0]
Character-level LSTM¶
Now that we know the basics of RNNs, let’s build a character-level LSTM language-model. We have a sequence LSTM that, at each step, gets as input a character, and needs to predict the next character.
In [17]:
import random
from collections import defaultdict
from itertools import count
import sys
LAYERS = 2
INPUT_DIM = 50
HIDDEN_DIM = 50
characters = list("abcdefghijklmnopqrstuvwxyz ")
characters.append("<EOS>")
int2char = list(characters)
char2int = {c:i for i,c in enumerate(characters)}
VOCAB_SIZE = len(characters)
In [18]:
pc = dy.ParameterCollection()
srnn = dy.SimpleRNNBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
lstm = dy.LSTMBuilder(LAYERS, INPUT_DIM, HIDDEN_DIM, pc)
params = {}
params["lookup"] = pc.add_lookup_parameters((VOCAB_SIZE, INPUT_DIM))
params["R"] = pc.add_parameters((VOCAB_SIZE, HIDDEN_DIM))
params["bias"] = pc.add_parameters((VOCAB_SIZE))
# return compute loss of RNN for one sentence
def do_one_sentence(rnn, sentence):
# setup the sentence
dy.renew_cg()
s0 = rnn.initial_state()
R = dy.parameter(params["R"])
bias = dy.parameter(params["bias"])
lookup = params["lookup"]
sentence = ["<EOS>"] + list(sentence) + ["<EOS>"]
sentence = [char2int[c] for c in sentence]
s = s0
loss = []
for char,next_char in zip(sentence,sentence[1:]):
s = s.add_input(lookup[char])
probs = dy.softmax(R*s.output() + bias)
loss.append( -dy.log(dy.pick(probs,next_char)) )
loss = dy.esum(loss)
return loss
# generate from model:
def generate(rnn):
def sample(probs):
rnd = random.random()
for i,p in enumerate(probs):
rnd -= p
if rnd <= 0: break
return i
# setup the sentence
dy.renew_cg()
s0 = rnn.initial_state()
R = dy.parameter(params["R"])
bias = dy.parameter(params["bias"])
lookup = params["lookup"]
s = s0.add_input(lookup[char2int["<EOS>"]])
out=[]
while True:
probs = dy.softmax(R*s.output() + bias)
probs = probs.vec_value()
next_char = sample(probs)
out.append(int2char[next_char])
if out[-1] == "<EOS>": break
s = s.add_input(lookup[next_char])
return "".join(out[:-1]) # strip the <EOS>
# train, and generate every 5 samples
def train(rnn, sentence):
trainer = dy.SimpleSGDTrainer(pc)
for i in xrange(200):
loss = do_one_sentence(rnn, sentence)
loss_value = loss.value()
loss.backward()
trainer.update()
if i % 5 == 0:
print loss_value,
print generate(rnn)
Notice that:
1. We pass the same rnn-builder to do_one_sentence over and over again. We must re-use the same rnn-builder, as this is where the shared parameters are kept.
2. We call dy.renew_cg() before each sentence, because we want to have a new graph (new network) for this sentence. The parameters will be shared through the model and the shared rnn-builder.
In [19]:
sentence = "a quick brown fox jumped over the lazy dog"
train(srnn, sentence)
142.737915039 lvawhaevbxulc yxg esuh vkyb gymj dzcnwgq dcjzzk
84.1147460938 woifoa odp jpt gxjofkaattj
44.212223053 a q io uoopr ouxducmwi jfxa j
23.4485988617 p tctflr
9.73490333557 w
3.23773050308 yaqzteu pux oa rntd bxumu yyvvfalejuyhed over the lazy dog
1.06309330463 a quick browe fow jumped over the lazy dog
0.671298980713 a quick broyn ox jumped over the lazy dog
0.490513861179 a quick brown fox jumped over the lazy dog
0.386095941067 a quick brown fox jumped over the lazy dog
0.318082690239 a quick brown fox jumped over the lazy dog
0.270276993513 a quick brown fox jumped over the lazy dog
0.234851941466 a quick brown foz jumped over the lazy dog
0.207555636764 a quick brown fox jumped over the lazy dog
0.185884565115 a quick brown fox jumped over the lazy dog
0.168265148997 a quiuk brown fox jumped over jhe lazy dog
0.153665527701 a quick brown fox jumped over the lazy dog
0.141367897391 a quick brown fox jumped over the lazy dog
0.130873680115 a quick brown fox jumped over the lazy dog
0.121810980141 a quick brown fox jumped over the lazy dog
0.113908931613 a quick brown fox jumped over the lazy dog
0.106958284974 a quick brown fox jumped over the lazy dog
0.100796818733 a quick brown fox jumped over the lazy dog
0.0953008085489 a quick brown fox jumped over the lazy dog
0.090367347002 a zuick brown for jumped over the lazy dog
0.0859087407589 a quick brown fox jumped over the lazy dog
0.0818664133549 a quick brown fox jumped over the lazy dog
0.0781841799617 a quick brown fox jumped over the lazy dog
0.0748091414571 a quick brown fox jumped over the lazy dog
0.0717144161463 a quick brown fox jumped over the lazy dog
0.0688648074865 a quick brown fox jumped over the lazy dog
0.0662328600883 a quick brown fox jumped over the lazy dog
0.0637853741646 a quick brown fox jumped over the lazy dog
0.0615109689534 a quick brown fox jumped over the lazy dog
0.0593910999596 a quick brown fox jumped over the lazy dog
0.0574130378664 a quick brown fox jumped over the lazy dog
0.0555621087551 a quick brown fox jumped over the lazy dog
0.0538215488195 a quick brown fox jumped over the lazy dog
0.0521896965802 a quick brown fox jumped over the lazy dog
0.0506477579474 a quick brown fox jumped over the lazy dog
In [20]:
sentence = "a quick brown fox jumped over the lazy dog"
train(lstm, sentence)
141.891098022 aoyekppy mocalmz xk atc jlg oaddk
128.925964355 hempeyud ki
121.445785522 qpveti fyobec ztmr eioknnueh ehecdvabxmc ydpmdm
110.670722961 z buws lmy vvrw
93.5055999756 vueoa cprlnkrd o ocazk nb olegiep o fftr t
82.1586227417 zj rvsr oej c toz bnarreow fffj
67.430847168 rzfik qoyc ohe hqe oea uitet ou udjkpme oak kdk oe fbu kcz fox dfoprl too o rxat luurnfowrrtj rbtram to url xlj okrr ooe otm hcy roab llsg doy ifzw rrbow rbowwb oke jxpee
54.9477920532 ba uiy doge she ueeze oejv
43.3301696777 qquc crgibbroej oxne ove rr
34.4687461853 uqckk owrbfo og uouk doge l
25.5408306122 reuk lfr own fox juamd ov
18.9417610168 qojn doo broww boan jover txe zacy moen crlw numk fox joge overwa trez quqk browx ox ruor oro fow j uoez kon fror bowe luccmd ogwr foy jodmoed ox
13.1646575928 qucy dov
9.46595668793 wiuuik brttxl laed over tre lazy dog
5.6522898674 rukc irown fox juaped over the lazy dov
3.38144731522 a quick brown fox jumver the lazy dog
1.80010521412 a bfoin fox jumped ovk fox luick brown fox jumped over the lazy dog
1.30616080761 a quic brownn fox jumped over the lazy dog
1.02201879025 a quick brown fox jumped over the lazy dog
0.83735615015 qucck brown fox jcmped over the lazy dog
0.708056390285 a quickz brown fox jumped over the lazy dog
0.612650871277 a quick brown fox jumped over the lazy dog
0.539469838142 a quick brown fox jumped over thel lazy dog
0.481610894203 va quick brown fox jumped over the lazy dog
0.434762001038 a quuck dovtbown fox jumped over the lazy dog
0.396079242229 a quick brown fox jumped over the lazy dog
0.363606244326 a quick brown fox jumped over the laza dog
0.335973978043 a quick brown fox jumped over the lazy dog
0.312186658382 a quick brown fox jumped over the lazy dog
0.291498303413 a quick brown fox qu
0.273335546255 a quick brown fox jumped ove
0.257278442383 a quick brown fox jumped over the lazy dog
0.242971763015 a quick brown fox jumped over the lazy dog
0.230153128505 a quick brown fox jumped over the lazy dog
0.218599274755 a quick brown fox jumped over the lazy dog
0.208135351539 a quick brown fox jumped over the lazy dog
0.198613137007 a quick brown fox jumped over tie lazy dog
0.189909905195 a quick brown fox jumped over the lazy dog
0.181928783655 a quick brown fox jumped over the lazy dog
0.174587100744 a quick brown fox jumped over the lazy dog
The model seems to learn the sentence quite well.
Somewhat surprisingly, the simple-RNN model learns faster than the LSTM!
How can that be?
The answer is that we are cheating a bit. The sentence we are trying to learn has each letter-bigram exactly once, so a simple trigram model can memorize it very well.
Try it out with more complex sequences (a sketch for training on several sentences follows the output below).
In [21]:
train(srnn, "these pretzels are making me thirsty")
332.651580811 a quick brown fox jumped over the lazy dog
133.209350586 a quick brown fox jumped over the lazy doe hu yum xd the
65.0720596313 azquick brown fox jumped over ohe iog
31.5592880249 a quick brown fox jumpedrovtretpede pretzelz are makink ma tui idmilt
13.2322559357 theve prwtumpede mhxtjaypny mreticv
1.87829053402 thele pretzelb mre laki loet dre za tuiri mtoina ma qui irwt ere sa taetsdaca qamtuioe ma ick mrolnn mhetsirstyyza qa luijuoethetsepsaaya quirk brmtze ehersjlyaa aumu orkrbtoeqz lrea quijk jrowza quiquihi sakiny mr tui ss thels theqetursy famtzi maethehe iretza lamqzd zretsels area qhirk browna yhetza quirkt rxkwn mox ja isi mq thirsty
0.680327475071 these pretzels are makind me thirsty
0.176128521562 these pretzels are making me thirsty
0.126334354281 these pretzels are making me thirsty
0.10075186193 these pretzels are making me thirsty
0.0846510156989 these pretzels are making me thirsty
0.0734022557735 these pretzels are making me thirsty
0.0650328546762 these pretzels are making me thirsty
0.0585154108703 these pretzels are making me thirsty
0.0532807298005 these pretzels are making me thirsty
0.0489665567875 these pretzels are making me thirsty
0.0453444086015 these pretzels are making me thirsty
0.0422535128891 these pretzels are making me thirsty
0.0395833179355 these pretzels are making me thirsty
0.0372485220432 these mretzels are making me thirsty
0.0351839251816 these pretzels are making me thirsty
0.0333509668708 these pretzels are making me thirsty
0.0317104011774 these pretzels are making me thirsty
0.0302277039737 these pretzels are making me thirsty
0.0288887582719 these pretzels are making me thirsty
0.0276643745601 these pretzels are making me thirsty
0.0265435613692 these pretzels are making me thirsty
0.0255212895572 these pretzels are making me thirsty
0.0245705824345 these pretzels are making me thirsty
0.0236932244152 these pretzels are making me thirsty
0.0228785891086 these pretzels are making me thirsty
0.0221205893904 these pretzels are making me thirsty
0.0214090794325 these pretzels are making me thirsty
0.0207556784153 these pretzels are making me thirsty
0.0201329570264 these pretzels are making me thirsty
0.0195484217256 these pretzels are making me thirsty
0.0190003421158 these pretzels are making me thirsty
0.0184785164893 these pretzels are making me thirsty
0.0179911740124 these pretzels are making me thirsty
0.0175334792584 these pretzels are making me thirsty
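As a minimal sketch of what training on more complex sequences might look like, the following loop (not part of the original tutorial) reuses do_one_sentence, generate and pc from above; the function name train_multi, the sentence list and the iteration count are placeholders you can change freely.
# Minimal sketch (assumption: do_one_sentence, generate and pc are defined as above).
def train_multi(rnn, sentences, iterations=100):
    trainer = dy.SimpleSGDTrainer(pc)
    for i in xrange(iterations):
        total_loss = 0.0
        for sent in sentences:
            loss = do_one_sentence(rnn, sent)   # builds a fresh graph per sentence
            total_loss += loss.value()
            loss.backward()
            trainer.update()
        if i % 5 == 0:
            print total_loss, generate(rnn)

sentences = ["a quick brown fox jumped over the lazy dog",
             "these pretzels are making me thirsty"]
train_multi(srnn, sentences)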
DyNet Autobatch¶
Friends don’t let friends write batching code¶
Modern hardware processors (CPUs and GPUs) can exploit parallelism to a great extent, so batching is good for speed. But it is so annoying to write batching code for RNNs or more complex architectures: you must take care of padding, masking, and indexing, and that's just for the easy cases… Not any more!
We’ve added a feature to DyNet that will transform the way you think about and run batching code. The gist of it is: you aggregate a large enough computation graph to make batching possible. DyNet figures out the rest, and does the batching for you.
In what follows, we show some examples of non-batched DyNet code, and then move on to show the batched version.
In order to enable auto-batching support, simply add --dynet-autobatch 1 to the command-line flags when running a DyNet program. Check out the paper or read on for more details!
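If you prefer to enable it from inside your script rather than on the command line, the dynet_config module (described later in this document) can set the flag before DyNet is imported. This is a minimal sketch, assuming dynet_config.set accepts an autobatch argument (autobatch is listed among the configuration attributes later in this document):
# enable autobatching programmatically; must run before "import dynet"
import dynet_config
dynet_config.set(autobatch=True)   # assumed equivalent of --dynet-autobatch 1
import dynet as dy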
Dynamic Graphs, Non-batched¶
Let’s look at some examples of non-batched code, and how simple they are to write in DyNet.
Our first example will be an acceptor LSTM that reads in a sequence of vectors, passes the final vector through a linear layer followed by a softmax, and produces an output.
In [1]:
import dynet as dy
import numpy as np
In [2]:
# acceptor LSTM
class LstmAcceptor(object):
def __init__(self, in_dim, lstm_dim, out_dim, model):
self.builder = dy.VanillaLSTMBuilder(1, in_dim, lstm_dim, model)
self.W = model.add_parameters((out_dim, lstm_dim))
def __call__(self, sequence):
lstm = self.builder.initial_state()
        W = self.W.expr() # convert the parameter into an Expression (add it to the graph)
outputs = lstm.transduce(sequence)
result = W*outputs[-1]
return result
In [3]:
# usage:
VOCAB_SIZE = 1000
EMBED_SIZE = 100
m = dy.Model()
trainer = dy.AdamTrainer(m)
embeds = m.add_lookup_parameters((VOCAB_SIZE, EMBED_SIZE))
acceptor = LstmAcceptor(EMBED_SIZE, 100, 3, m)
# training code
sum_of_losses = 0.0
for epoch in range(10):
for sequence,label in [((1,4,5,1),1), ((42,1),2), ((56,2,17),1)]:
dy.renew_cg() # new computation graph
vecs = [embeds[i] for i in sequence]
preds = acceptor(vecs)
loss = dy.pickneglogsoftmax(preds, label)
sum_of_losses += loss.npvalue()
loss.backward()
trainer.update()
print sum_of_losses / 3
sum_of_losses = 0.0
print "\n\nPrediction time!\n"
# prediction code:
for sequence in [(1,4,12,1), (42,2), (56,2,17)]:
dy.renew_cg() # new computation graph
vecs = [embeds[i] for i in sequence]
preds = dy.softmax(acceptor(vecs))
vals = preds.npvalue()
print np.argmax(vals), vals
[ 1.1042192]
[ 1.03213656]
[ 0.97442627]
[ 0.91803074]
[ 0.86030102]
[ 0.79953943]
[ 0.73457642]
[ 0.66490026]
[ 0.59101043]
[ 0.51482052]
Prediction time!
1 [ 0.06114297 0.75843614 0.18042086]
1 [ 0.25732863 0.37167525 0.37099609]
1 [ 0.1679846 0.61701268 0.21500272]
This was simple. Notice how each sequence has a different length, but that's OK: the LstmAcceptor doesn't care. We create a new graph for each example, at exactly the desired length.
Similar to the LstmAcceptor, we could also write a TreeRNN that gets as input a tree structure and encodes it as a vector. Note that the code below is missing the support code for representing binary trees and reading trees from bracketed notation. All of these, along with the more sophisticated TreeLSTM version and the training code, can be found here.
In [6]:
class TreeRNN(object):
    def __init__(self, model, word_vocab, hdim):
        self.W = model.add_parameters((hdim, 2*hdim))
        self.E = model.add_lookup_parameters((len(word_vocab),hdim))
        self.w2i = word_vocab

    def __call__(self, tree): return self.expr_for_tree(tree)

    def expr_for_tree(self, tree):
        if tree.isleaf():
            return self.E[self.w2i.get(tree.label,0)]
        if len(tree.children) == 1:
            assert(tree.children[0].isleaf())
            expr = self.expr_for_tree(tree.children[0])
            return expr
        assert(len(tree.children) == 2),tree.children[0]
        e1 = self.expr_for_tree(tree.children[0])
        e2 = self.expr_for_tree(tree.children[1])
        W = dy.parameter(self.W)
        expr = dy.tanh(W*dy.concatenate([e1,e2]))
        return expr
Enter batching¶
Now, let’s add some minibatching support. The way we go about it is very simple: Your only responsibility, as a programmer, is to build a computation graph with enough material to make batching possible (i.e., so there is something to batch). DyNet will take care of the rest.
Here is the training and prediction code from before, this time written with batching support. Notice how the LstmAcceptor did not change; we just aggregate the loss around it.
In [5]:
# training code: batched.
for epoch in range(10):
dy.renew_cg() # we create a new computation graph for the epoch, not each item.
# we will treat all these 3 datapoints as a single batch
losses = []
for sequence,label in [((1,4,5,1),1), ((42,1),2), ((56,2,17),1)]:
vecs = [embeds[i] for i in sequence]
preds = acceptor(vecs)
loss = dy.pickneglogsoftmax(preds, label)
losses.append(loss)
# we accumulated the losses from all the batch.
# Now we sum them, and do forward-backward as usual.
# Things will run with efficient batch operations.
batch_loss = dy.esum(losses)/3
print batch_loss.npvalue() # this calls forward on the batch
batch_loss.backward()
trainer.update()
print "\n\nPrediction time!\n"
# prediction code:
dy.renew_cg() # new computation graph
batch_preds = []
for sequence in [(1,4,12,1), (42,2), (56,2,17)]:
vecs = [embeds[i] for i in sequence]
preds = dy.softmax(acceptor(vecs))
batch_preds.append(preds)
# now that we accumulated the prediction expressions,
# we run forward on all of them:
dy.forward(batch_preds)
# and now we can efficiently access the individual values:
for preds in batch_preds:
vals = preds.npvalue()
print np.argmax(vals), vals
[ 0.46247479]
[ 0.43548316]
[ 0.40905878]
[ 0.38335174]
[ 0.35849127]
[ 0.3345806]
[ 0.31169581]
[ 0.28988609]
[ 0.26917794]
[ 0.24957809]
Prediction time!
1 [ 0.00736407 0.95775431 0.03488157]
2 [ 0.2252606 0.36341026 0.41132909]
1 [ 0.05491769 0.85925961 0.08582276]
Doing the same thing for the TreeRNN example is trivial: just aggregate the expressions from several trees, and then call forward. (In fact, you may receive a small boost from the auto-batching feature also within a single tree, as some computation can be batched there also.)
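For concreteness, here is a minimal sketch of what that aggregation might look like. The names tree_rnn and trees are illustrative, assuming a TreeRNN instance and a list of parsed trees built with the support code mentioned above.
# Minimal sketch; `tree_rnn` and `trees` are assumed to exist.
dy.renew_cg()
exprs = [tree_rnn(t) for t in trees]   # aggregate expressions from several trees
dy.forward(exprs)                      # one forward call; autobatching happens here
for e in exprs:
    print e.npvalue()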
Comparison to manual batching¶
We compared the speed of automatic batching as shown above to manually crafted batching code, in a setting in which manual batching excels: BiLSTM tagging where all the sentences are of exactly the same length. Here, automatic batching improved the per-sentence computation time from 193ms to 16.9ms on CPU and from 54.6ms to 5.03ms on GPU, an approximately 11-fold increase in sentences processed per second (5.17 -> 59.3 on CPU and 18.3 -> 198 on GPU). However, manual batching is still 1.27 times faster on CPU, and 1.76 times faster on GPU.
The advantage of manual batching seems to come mostly from the time it takes to create the computation graph itself: with manual batching we create a single graph with many inputs, while with automatic batching we essentially build many copies of the same graph for each batch. Should you use manual batching then? In situations in which it is very natural, like this artificial one, sure! But in cases where manual batching is not so trivial (which is most cases; see some examples below), go ahead and use the automatic version. It works.
You can also run automatic batching on top of manually batched code. When doing this, we observe another 10% speed increase over the manually batched code when running on the GPU. This is because the autobatching engine managed to find and exploit some additional batching opportunities. On the CPU, we did not observe any gains in this setting, but also did not observe any losses.
How big is the win?¶
So the examples above are rather simple, but how does this help on actual applications? We've run some experiments on several natural language processing tasks, including POS tagging with bidirectional LSTMs, POS tagging with BiLSTMs that also have character embeddings (which is harder to batch), tree-structured neural networks, and a full-scale transition-based dependency parser. Each of these uses a batch size of 64 sentences at a time, without worrying about length balancing or anything of that sort. As you can see from the results below (in sentences/second), auto-batching gives you healthy gains of 3x to 9x over no auto-batching. This is with basically no effort required!
Task | No Autobatch (CPU) | Autobatch (CPU) | No Autobatch (GPU) | Autobatch (GPU) |
---|---|---|---|---|
BiLSTM | 16.8 | 156 | 56.2 | 367 |
BiLSTM w/ char | 15.7 | 132 | 43.2 | 275 |
TreeNN | 50.2 | 357 | 76.5 | 661 |
Transition Parser | 16.8 | 61.2 | 33.0 | 90.1 |
If you want to try these benchmarks yourself, take a look at the
...-bulk
programs in the
dynet-benchmark
repository.
In the graph below you can see the number of sentences/second for training the transition-based parser with various batch sizes, on the GPU, CPU, and CPU with MKL enabled:
The following graph shows the number of sentences/second for the Tree-LSTM model for various batch sizes, and also compares to the TensorFlow Fold implementation, which is another proposed solution for batching hard-to-batch architectures. Note that DyNet autobatching comfortably wins over TensorFlow Fold on both GPU and CPU, with CPU being more efficient than GPU for smaller batch sizes.
Miscellaneous tips¶
Should you always use batching?¶
It depends. At prediction time, batching is a pure win in terms of speed. At training time, the sentences/second throughput will be much better, but you will also have fewer parameter updates, which may make overall training slower. Experiment with different batch sizes to find a good tradeoff between the two.
Length-balanced batches?¶
It is common knowledge when writing batched code that one should arrange the batches such that all examples within the batch are of the same size. This is crucial for static frameworks and manual batching, as it reduces the need for padding, masking and so on. In our framework, this is not needed. However, you may still win some speed by having relatively-balanced batches, because more batching opportunities will become available.
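If you do want to try this, a minimal sketch of length-based bucketing might look as follows; the function name and arguments are illustrative, not part of DyNet's API.
# Minimal sketch of grouping sentences by length before batching.
from collections import defaultdict

def length_balanced_batches(sentences, batch_size):
    # group sentences by length, then cut each group into batches
    by_len = defaultdict(list)
    for sent in sentences:
        by_len[len(sent)].append(sent)
    batches = []
    for group in by_len.values():
        for i in range(0, len(group), batch_size):
            batches.append(group[i:i+batch_size])
    return batches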
Tips for effective autobatching¶
As we said above, our only rule is "create a graph with enough material for the autobatcher to work with". In other words, delay the call to forward() (or to value(), npvalue(), scalar_value(), …) as much as possible. Beyond that, things should be transparent.
However, knowing some technicalities of DyNet and how forward works can help you avoid some pitfalls. So here is a brief overview:
- The core building blocks of DyNet are Expression objects. Whenever you create a new Expression, you extend the computation graph.
- Creating an Expression does not entail a forward computation. We only evaluate the graph when specifically asked for it.
- Calls to e.forward(), e.value(), e.npvalue(), e.scalar_value() will run forward computation up to that expression, and return a value.
- These calls will compute all the expressions that were added to the graph before e. These intermediary results will be cached.
- Asking for values for (or calling forward on) earlier expressions will reuse the cached values.
- You can extend the graph further after calling forward. Later calls will compute the graph delta.
So, based on this knowledge, here is the rule:
If you created several expressions, and want to get the values for them, call forward on the last expression first, and then on the previous ones.
Doing it the other way around (evaluating the expressions in the order they were created) will hinder batching possibilities, because it will compute only a small incremental part of forward for each expression. On the other hand, if you run forward on the last expression first, the entire computation will happen in one chunk, batching when possible. Calling npvalue() on the earlier expressions will then return the already computed values.
If you created a bunch of expressions and are not sure which one is the latest, you can just call the special list version of forward: dy.forward([e1,e2,...,en]), and it will figure it out for you.
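For example, the following sketch builds all the loss expressions first and only then evaluates them, so the whole batch is computed in one forward pass. compute_loss and batch are illustrative names, not DyNet API.
# Minimal sketch; `compute_loss` and `batch` are assumed to exist.
dy.renew_cg()
losses = [compute_loss(example) for example in batch]   # build the whole graph first
dy.forward(losses)                  # or: losses[-1].forward() -- evaluate the last expression first
values = [l.value() for l in losses]  # earlier values are now just cache lookups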
Loose ends¶
Auto-batching in DyNet works and is stable. However, some of the less common operations are not yet batched. If you have an example where you think you should be getting a nice boost from autobatching but you don't, it is most likely that you are using a non-batched operation. In any case, let us know via an issue on GitHub, and we'll investigate.
Saving Models¶
DyNet provides the ability to save and restore model parameters. The user has several options for saving and restoring parameters.
Saving an entire model¶
In the first option, the complete ParameterCollection object is saved. At loading time, the user should define and allocate the same parameter objects that were present in the model when it was saved, and in the same order (this usually amounts to having the same parameter-creation code called by both code paths), and then call populate on the ParameterCollection object containing the parameters that should be loaded.
import dynet as dy
# save
m = dy.ParameterCollection()
a = m.add_parameters(100)
b = m.add_lookup_parameters((10, 100))
c = m.add_parameters(1000)
m.save("/tmp/tmp.model")
# load
m2 = dy.ParameterCollection()
x = m2.add_parameters(100)
y = m2.add_lookup_parameters((10, 100))
z = m2.add_parameters(1000)
m2.populate("/tmp/tmp.model")
Partial Saving And Loading (Low-level API)¶
(This API follows the C++ partial saving and loading paradigm. See below for a higher level pythonic API.)
In some cases it is useful to save only a subset of the parameter objects (for example, if users wish to load these in a pre-training setup). Here, Parameter or LookupParameter objects can be saved explicitly. Users can also specify keys for partial saving and loading.
import dynet as dy
# save
m1 = dy.ParameterCollection() # m1.name() == "/"
m2 = dy.ParameterCollection() # m2.name() == "/"
m3 = m1.add_subcollection("m3") # m3.name() == "/m3/"
a = m1.add_parameters(10, name="a") # a.name() == "/a"
L = m1.add_lookup_parameters((10, 2), name="la") # L.name() == "/la"
param_b = m2.add_parameters((3, 7)) # param_b.name() == "/_0"
param_c = m3.add_parameters((3, 7), name="pc") # param_c.name() == "/m3/pc"
param_d = m3.add_parameters((3, 7)) # param_d.name() == "/m3/_0"
L.save("/tmp/tmp.model", "/X") # ignores L.name(), saves L under "/X"
a.save("/tmp/tmp.model", append=True) # uses a.name()
param_c.save("/tmp/tmp.model", append=True)
param_b.save("/tmp/tmp.model", append=True)
param_d.save("/tmp/tmp.model", append=True)
# load
m = dy.ParameterCollection()
a2 = m.add_parameters(10)
L2 = m.add_lookup_parameters((10, 2))
c = m.add_parameters((3,7))
L2.populate("/tmp/tmp.model", "/X")
a2.populate("/tmp/tmp.model", "/a")
c.populate("/tmp/tmp.model", "/m3/pc")
(See the documentation of ParameterCollection for further information about sub-collections and the use of collection hierarchies.)
One can also save and load builder objects using their internal parameter collection.
# save
lstm = dy.LSTMBuilder(2, 100, 100, m1)
pc = lstm.param_collection() # pc.name() == "/lstm-builder/"
lstm2 = dy.LSTMBuilder(2, 50, 50, m1)
pc2 = lstm2.param_collection() # pc2.name() == "/lstm-builder_1/"
pc2.save("/tmp/tmp.model",append=False)
pc.save("/tmp/tmp.model",append=True)
# load
lstm2 = dy.LSTMBuilder(2, 50, 50, m)
lstm2.param_collection().populate("/tmp/tmp.model", "/lstm-builder_1/")
lstm = dy.LSTMBuilder(2, 100, 100, m)
lstm.param_collection().populate("/tmp/tmp.model", "/lstm-builder/")
Partial Saving And Loading (High-level API)¶
Use the module-level dy.save(basename, lst) and dy.load(basename, param_collection) methods.
dy.save takes a base filename and a list of saveable objects (see below), and saves them to file.
dy.load takes a base filename and a parameter collection (model), and returns a list of objects, in the same order in which they were passed to dy.save. The parameters of the objects are added to the model.
Notice that you do not need to specify sizes when loading.
import dynet as dy
pc = dy.ParameterCollection()
W = pc.add_parameters((100,50))
E = pc.add_lookup_parameters((1000,50))
builder_a = dy.LSTMBuilder(2, 50, 50, pc)
builder_b = dy.LSTMBuilder(2, 100, 100, pc)
dy.save("/tmp/model", [E, builder_b, W])
# this will create two files, "/tmp/model.data" and "/tmp/model.meta"
# then, when loading:
pc2 = dy.ParameterCollection()
E2, builder2, W2 = dy.load("/tmp/model", pc2)
What can be saved?¶
Each object in lst must be one of the following:
1. Parameter
2. LookupParameter
3. One of the built-in types (VanillaLSTMBuilder, LSTMBuilder, GRUBuilder, SimpleRNNBuilder, BiRNNBuilder)
4. A type adhering to the following interface:
  - has a .param_collection() method returning a ParameterCollection object with the parameters in the object.
  - has a picklable .spec property with items describing the object.
  - has a .from_spec(spec, model) static method that will create and return a new instance of the object with the needed parameters/etc.
Note, the built-in types in (3) above can be saved/loaded this way simply because they support this interface.
Behind the scenes:
- for each item, we write to basename.meta:
  - if it is a Parameters/ParameterCollection: its type and full name.
  - if it is a builder: its class, its spec, and the full name of its parameter collection.
- the associated parameters/sub-collection is then saved to basename.data
Example of a user-defined saveable type:¶
# Example of a user-defined saveable type.
class OneLayerMLP(object):
def __init__(self, model, num_input, num_hidden, num_out, act=dy.tanh):
pc = model.add_subcollection()
self.W1 = pc.add_parameters((num_hidden, num_input))
self.W2 = pc.add_parameters((num_out, num_hidden))
self.b1 = pc.add_parameters((num_hidden))
self.b2 = pc.add_parameters((num_out))
self.pc = pc
self.act = act
self.spec = (num_input, num_hidden, num_out, act)
def __call__(self, input_exp):
W1 = dy.parameter(self.W1)
W2 = dy.parameter(self.W2)
b1 = dy.parameter(self.b1)
b2 = dy.parameter(self.b2)
g = self.act
return dy.softmax(W2*g(W1*input_exp + b1)+b2)
# support saving:
def param_collection(self): return self.pc
@staticmethod
def from_spec(spec, model):
num_input, num_hidden, num_out, act = spec
return OneLayerMLP(model, num_input, num_hidden, num_out, act)
And for the usage:
import dynet as dy
m = dy.ParameterCollection()
# create an embedding table.
E = m.add_lookup_parameters((1000,10))
# create an MLP from 10 to 4 with a hidden layer of 20.
mlp = OneLayerMLP(m, 10, 20, 4, dy.rectify)
# use them together.
output = mlp(E[3])
# now save the model:
dy.save("basename",[mlp, E])
# now load:
m2 = dy.ParameterCollection()
mlp2, E2 = dy.load("basename", m2)
output2 = mlp2(E2[3])
import numpy
assert(numpy.array_equal(output2.npvalue(), output.npvalue()))
File format¶
Currently, DyNet only supports a plain text format. The native format is quite simple and therefore very readable. A model file consists of basic storage blocks. A basic block starts with a line of metadata: #object_type# object_name dimension block_size, followed by the actual data. During loading, DyNet uses the metadata lines to locate the objects the user wants to load.
In the pythonic high-level partial saving/loading API, the .data file adheres to the format above, while the .meta file contains information on object types and sizes (for the specifics of the .meta file format, see the code of _save_one and _load_one in _dynet.pyx).
A more comprehensive tutorial can be found here (EMNLP 2016 tutorial).
Command Line Options¶
All programs using DyNet have a few command line options. These must be specified at the very beginning of the command line, before other options.
--dynet-mem NUMBER: DyNet runs by default with 512MB of memory, which is split evenly for the forward and backward steps, parameter storage, and scratch use. This will be expanded automatically every time one of the pools runs out of memory. By setting NUMBER here, DyNet will allocate more memory immediately at the initialization stage. Note that you can also individually set the amount of memory for forward calculation, backward calculation, parameters, and scratch use with comma-separated values: --dynet-mem FOR,BACK,PARAM,SCRATCH. This is useful if, for example, you are performing testing and don't need to allocate any memory for backward calculation.
--dynet-weight-decay NUMBER: Adds weight decay to the parameters, which modifies each parameter w such that w *= (1-weight_decay) after every update. This is similar to L2 regularization, but different in a couple of ways, which are noted in detail in the "Unorthodox Design" section.
--dynet-autobatch NUMBER: Turns on DyNet's automatic operation batching capability. This makes it possible to speed up computation with a minimum of work. More information about this functionality can be found here.
--dynet-gpus NUMBER: Specify how many GPUs you want to use, if DyNet is compiled with CUDA.
--dynet-gpu: Specify whether to use the GPU or not. Note that this is an option for Python programs.
--dynet-devices CPU,GPU:1,GPU:3,GPU:0: Specify the CPU/GPU devices that you want to use. You can specify the physical ID for a GPU but cannot specify an ID for the CPU. This is a useful option when working with multi-device code. Currently, DyNet needs you to specify the device IDs explicitly. The option --dynet-gpu-ids is deprecated.
--dynet-profiling NUMBER: Outputs information about the amount of time/memory used by each node in the graph. Profiling levels are 0, 1, and 2.
Debugging/Reporting Issues¶
There are a number of tools to make debugging easier in DyNet. In addition, we welcome any questions or issues, and will be able to respond most effectively if you follow the guidelines below.
Debugging Tools¶
Visualization¶
It is possible to create visualizations of the computation graph by calling the print_graphviz()
function, which can be helpful to debug. When this functionality is used in Python, it is necessary to add the command line argument --dynet-viz
. In Python, there is also a print_text_graphviz()
function which will be less pretty than the print_graphviz()
function, but doesn’t require the command line flag.
Immediate Computation¶
In general, DyNet performs symbolic execution. This means that you first create the computation graph, and the computation is actually performed when you request a value using functions such as forward() or value(). However, if an error occurs during calculation, it can be hard to debug, because the error doesn't occur immediately where the offending graph node is created. To make debugging simpler, you can use immediate computing mode in DyNet. In this mode, every computation gets executed immediately, just like imperative programming, so that you can find exactly where things go wrong.
In C++, you can switch to the immediate computing mode by calling ComputationGraph::set_immediate_compute as follows:
ComputationGraph cg;
cg.set_immediate_compute(true);
Further, DyNet can automatically check the validity of your model (i.e., detect Inf/NaN values) if it is in immediate computing mode. To activate validity checking, add the following code after switching to immediate computing mode.
cg.set_check_validity(true);
In Python, these values can be set by using optional arguments to the renew_cg()
function as follows:
dy.renew_cg(immediate_compute = True, check_validity = True)
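As a minimal sketch of what this buys you (the specific expressions are illustrative): with immediate computation and validity checking on, a NaN produced by the log of a negative number is reported at the line that creates the node, rather than at a later forward() or value() call.
import dynet as dy
dy.renew_cg(immediate_compute=True, check_validity=True)
x = dy.inputVector([1.0, -1.0])
# log of a negative number produces a NaN; with the flags above the error
# is raised here, where the offending node is created
y = dy.log(x)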
Debug Builds¶
By default, DyNet is built with all optimizations enabled.
You can build DyNet without optimizations by adding
-DCMAKE_BUILD_TYPE=Debug
to the cmake command
cd dynet
mkdir build
cd build
cmake .. -DCMAKE_BUILD_TYPE=Debug
make -j8 # replace 8 with the number of cores on your machine
Note: pass other cmake options based on your environment.
Debugging Crashes¶
Build with ASan¶
If you’re on Linux or macOS, you can build DyNet with
AddressSanitizer
(aka ASan). ASan is a memory error detector for C/C++. It’s useful for debugging
bugs or crashes caused by memory errors such as use-after-free, heap buffer overflow,
stack buffer overflow. By running ASan-enabled tests or programs, ASan finds memory
errors at runtime. To enable ASan, add -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-fsanitize=address"
to the cmake command:
cd dynet
mkdir build-asan
cd build-asan
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-fsanitize=address"
make -j8
Please see the official wiki for the details.
CAUTION: Please do not install ASan enabled libraries or programs under root partition. You might have a bad time.
Debugging Threading Issues¶
Build with TSan¶
Linux/macOS only.
If you’re on Linux or macOS, you can build DyNet with ThreadSanitizer (aka TSan). TSan is a data race error detector for C/C++. It finds data races at runtime just like ASan. Please see the official wiki for more details.
By running TSan-enabled tests or programs, TSan finds data races at runtime.
To enable TSan, add -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-fsanitize=thread"
to the cmake command:
cd dynet
mkdir build-tsan
cd build-tsan
cmake .. -DCMAKE_BUILD_TYPE=Debug -DCMAKE_CXX_FLAGS="-fsanitize=thread"
CAUTION: Please do not install TSan enabled libraries or programs under root partition. You might have a bad time.
Asking Questions/Reporting Bugs¶
Feel free to contact the dynet-users group or file an issue on github with any questions or problems.
(If you subscribe to dynet-users
and want to receive email make sure to select “all email” when you sign up.)
When you have an issue, including the following information in your report will greatly help us debug:
- What is the error? Copy and paste the error message.
- What is your environment? Are you running on CPU or GPU? What OS? If the problem seems to be related to a specific library (CUDA, Eigen), what version of that library are you using?
- If possible, it will be really really helpful if you can provide a minimal code example that will cause the problem to occur. This way the developers will be able to reproduce the problem in their own environment.
If you have a build problem and want to debug, please run
make clean
make VERBOSE=1 &> make.log
then examine the commands in the make.log
file to see if anything
looks fishy. If you would like help, send this make.log
file via the
“Issues” tab on GitHub, or to the dynet-users mailing list.
Python Reference Manual¶
Dynet global parameters¶
DynetParams¶
-
class
dynet.
DynetParams
¶ This object holds the global parameters of Dynet
This is useful if you want to specify the global dynet parameters (memory, random seed…) programmatically, for example in a notebook.
import _dynet
You can then declare and use a DynetParams object:
# Declare a DynetParams object
dyparams = dy.DynetParams()
# Fetch the command line arguments (optional)
dyparams.from_args()
# Set some parameters manually (see the command line arguments documentation)
dyparams.set_mem(2048)
dyparams.set_random_seed(666)
# Initialize with the given parameters
dyparams.init()  # or init_from_params(dyparams)
You can also use the dynet_config object in your script to specify the device usage and the global dynet parameters (memory, random seed…) before import dynet:
import dynet_config
# Declare GPU as the default device type
dynet_config.set_gpu()
# Set some parameters manually
dynet_config.set(mem=4, random_seed=9)
# Initialize dynet import using above configuration in the current scope
import dynet
Don’t forget to initialize with
dyparams.init()
, otherwise dynet will raise an error.-
from_args
(shared_parameters=None)¶ Gets parameters from the command line arguments
You can still modify the parameters after calling this. See the documentation about command line arguments for more details
Keyword Arguments: shared_parameters ([type]) – [description] (default: None)
-
from_config
(conf)¶ Set parameters from config object:
- Attributes of conf object:
- mem, seed, autobatch, profiling, weight_decay, shared_params, requested_gpus, gpu_mask
-
init
()¶ Initialize dynet with the current dynetparams object.
This is a one-way operation: you can't uninitialize dynet
-
set_autobatch
(autobatch)¶ Activate autobatching
Parameters: autobatch (bool) – Set to True
to activate autobatching
-
set_mem
(mem)¶ Set the memory allocated to dynet
The unit is MB
Parameters: mem (number) – memory size in MB
-
set_profiling
(profiling)¶ Activate profiling
Parameters: profiling (int) – Set to a value > 0 to activate profiling
-
set_random_seed
(random_seed)¶ Set random seed for dynet
Parameters: random_seed (number) – Random seed
-
set_requested_gpus
(requested_gpus)¶ Number of requested gpus
Parameters: requested_gpus (number) – number of requested gpus
-
set_shared_parameters
(shared_parameters)¶ Shared parameters
Parameters: shared_parameters (bool) – shared parameters
-
set_weight_decay
(weight_decay)¶ Set weight decay parameter
Parameters: weight_decay (float) – weight decay parameter
-
Initialization functions¶
-
dynet.
init
(shared_parameters=None)¶ Initialize dynet
Initializes dynet from command line arguments. Do not use after import dynet
Keyword Arguments: shared_parameters (bool) – [description] (default: None)
-
dynet.
init_from_params
(params)¶ Initialize from DynetParams
Same as
params.init()Parameters: params (DynetParams) – dynet parameters
ParameterCollection and Parameters¶
ParameterCollection¶
-
class
dynet.
ParameterCollection
(parent=None)¶ A ParameterCollection holds Parameters. Use it to create, load and save parameters.
(It used to be called Model in previous versions of DyNet, and Model is still an alias for ParameterCollection.)
A ParameterCollection is a container for Parameters and LookupParameters.
dynet.Trainer objects take ParameterCollection objects that define which parameters are being trained.
The values of the parameters in a collection can be persisted to and loaded from files.
- Hierarchy:
- The parameter collections can be nested, where each collection can hold zero or more sub-collection, which are also ParameterCollection objects. Each (sub-)collection contains the parameters in it and in all the (sub-)collections below it.
- Naming:
Parameters, LookupParameters and ParameterCollections have associated string names. The names can be accessed using the .name() method.
The names are used for identifying the parameters and the collection hierarchy when loading from disk, and in particular when loading only a subset of the objects in a saved file.
The name of a parameter, lookup parameter or sub-collection is unique within a ParameterCollection, and reflects the hierarchy structure.
One can supply an optional informative name when creating the parameter or sub-collection. The supplied names are then appended with a running index to avoid name clashes. The .name() method returns the full name of an object, including the appended index and its location within the collection hierarchy. The user-supplied names cannot include the characters / (which is used as a hierarchy separator) or
_
(which is used as an index separator).
-
add_lookup_parameters
(dim, init=None, name='', device='')¶ Add a lookup parameter to the ParameterCollection
Parameters: dim (tuple) – Shape of the parameter. The first dimension is the vocab size
Keyword Arguments: - init (dynet.PyInitializer) – Initializer (default: GlorotInitializer)
- name (string) – Optional name for this parameter (default: “”)
- device (string) – Optional device name for this parameter (default: “”, default device)
Returns: Created LookupParameter
Return type:
-
add_parameters
(dim, init=None, name='', device='')¶ Add a parameter to the ParameterCollection
Parameters: dim (tuple) – Shape of the parameter
Keyword Arguments: - init (dynet.PyInitializer) – Initializer (default: GlorotInitializer)
- name (string) – Optional name for this parameter (default: “”)
- device (string) – Optional device name for this parameter (default: “”, default device)
Returns: Created Parameter
Return type:
-
add_subcollection
(name=None)¶ Creates a sub-collection of the current collection, and returns it.
A sub-collection is simply a ParameterCollection object which is tied to a parent collection. ParameterCollections can be nested to arbitrary depth.
Sub-collections are used for grouping of parameters, for example if one wants to train only a subset of the parameters, one can add them in a subcollection and pass the subcollection to a trainer. Similarly, for saving (or loading) only some of the parameters, one can save/populate a sub-collection.
Sub-collections are used inside builder objects (such as the LSTMBuilder): the builder creates a local sub-collection and adds parameters to it instead of to the global collection that is passed to it in the constructor. This way, the parameters participating in the builder are logically grouped, and can be saved/loaded/trained separately if needed.
Parameters: name (string) – an optional name for the sub-collection. Keyword Arguments: name (string) – Optional name for this sub-collection (default: “”) Returns: (dynet.ParameterCollection) a parameter collection.
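A minimal usage sketch (the names and shapes are illustrative):
import dynet as dy
pc = dy.ParameterCollection()
encoder = pc.add_subcollection("encoder")   # encoder.name() == "/encoder/"
W_enc = encoder.add_parameters((10, 10))
# train or save only the encoder's parameters:
trainer = dy.SimpleSGDTrainer(encoder)
encoder.save("/tmp/encoder.model")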
-
load_lookup_param
(fname, key)¶ Loads a named lookup-parameter from a file, adds it to the collection, and returns the loaded parameter.
Parameters: - fname (string) – the file name to read from.
- key (string) – the full-name of the lookup parameter to read.
Returns: (dynet.LookupParameters) The LookupParameters object.
-
load_param
(fname, key)¶ Loads a named parameter from a file, adds it to the collection, and returns the loaded parameter.
Parameters: - fname (string) – the file name to read from.
- key (string) – the full-name of the parameter to read.
Returns: (dynet.Parameters) The Parameters object.
-
lookup_parameters_from_numpy
(array, name='', device='')¶ Create LookupParameters from numpy array
Parameters: - array (np.ndarray) – Numpy array. rows: vocab_size, cols: dims.
- name (string) – optional name for this parameter.
- device (string) – Optional device name for this parameter (default: “”, default device)
Returns: LookupParameter
Return type:
-
lookup_parameters_list
()¶ Returns a list of all lookup parameters in the collection
Returns: All dy.LookupParameters in the collection Return type: (list)
-
name
()¶ Return the full name of this collection.
-
parameters_from_numpy
(array, name='', device='')¶ Create parameter from numpy array
Parameters: - array (np.ndarray) – Numpy array
- name (string) – optional name for this parameter.
- device (string) – Optional device name for this parameter (default: “”, default device)
Returns: Parameter
Return type:
-
parameters_list
()¶ Returns list of all parameters in the collection
Returns: All dy.Parameters in the collection Return type: (list)
-
populate
(fname, key='')¶ Populate the values of all parameters in this collection from file.
This only populates the values of existing parameters, and does not add parameters to the collection. Thus, the content of the file and the parameters in this collection must match. One should make sure to add to the collection the same parameters (and in the same order) before calling populate, as the ones that were added before calling save.
Parameters: fname (string) – file name to read parameter values from.
-
save
(fname, name='', append=False)¶ Save the values of all parameters in this collection to file.
Parameters: fname (string) – file name to save into.
Parameters and LookupParameters¶
-
class
dynet.
Parameters
¶ Parameters class
Parameters are things that are optimized. In contrast to a system like Torch where computational modules may have their own parameters, in DyNet parameters are just parameters.
-
as_array
()¶ Return as a numpy array.
Returns: values of the parameter Return type: np.ndarray
-
clip_inplace
(left, right)¶ Clip the values in the parameter to a fixed range [left, right] (in place)
Parameters: - left (number) – lower bound of the clipping range
- right (number) – upper bound of the clipping range
-
expr
(update=True)¶ Returns the parameter as an expression
This is the same as calling
dy.parameter(param)Parameters: update (bool) – If this is set to False, the parameter won’t be updated during the backward pass Returns: Expression of the parameter Return type: Expression
-
grad_as_array
()¶ Return gradient as a numpy array.
Returns: values of the gradient w.r.t. this parameter Return type: np.ndarray
-
is_updated
()¶ check whether the parameter is updated or not
Returns: Update status Return type: bool
-
name
()¶ Return the full name of this parameter.
-
populate
(fname, key)¶ Populate the values of this Parameters object from the parameter named key in the file fname. The sizes of saved parameters and this object must match.
Parameters: - fname (string) – the name of a file to load from.
- key (string) – the parameter to read from the file.
-
scale
(s)¶ Scales the parameter
Parameters: s (float) – Scale
-
scale_gradient
(s)¶ Scales the gradient
Parameters: s (float) – Scale
-
set_updated
(b)¶ Set parameter as “updated”
Parameters: b (bool) – updated status
-
set_value
(arr)¶ Set value of the parameter
-
shape
()¶ Returns shape of the parameter
Returns: Shape of the parameter Return type: tuple
-
zero
()¶ Set the parameter to zero
-
-
class
dynet.
LookupParameters
¶ LookupParameters represents a table of parameters.
They are used to embed a set of discrete objects (e.g. word embeddings). These are sparsely updated.
-
as_array
()¶ Return as a numpy array.
The first dimension is the lookup dimension
Returns: Values Return type: np.array
-
batch
(i)¶ Returns a batched expression based on looked up indices
This does the same as
dynet.lookup_batch
Parameters: i (list) – list of indices Returns: Batched expression of batch dimension len(i)
Return type: dynet.Expression
-
expr
(update=True)¶ Returns an expression for the whole parameter
Same as
dynet.parameter
Parameters: update (bool) – If this is set to False, the parameter won’t be updated during the backward pass Returns: Expression of the parameter Return type: Expression
-
grad_as_array
()¶ Return gradients as a numpy array.
The first dimension is the lookup dimension
Returns: gradient values Return type: np.array
-
init_from_array
(arr)¶ Initializes the values according to a numpy array
Preferably uses ParameterCollection.lookup_parameter_from_numpy when possible
Parameters: arr (np.array) – numpy array of shape (num_lookups,...)
-
init_row
(i, row)¶ Initialize one row with values
Parameters: - i (int) – index
- row (list) – values
-
name
()¶ Return the full name of this lookup parameter.
-
populate
(fname, key='')¶ Populate the values of this LookupParameters object from the parameter named key in the file fname. The sizes of saved parameters and this object must match.
Parameters: - fname (string) – the name of a file to load from.
- key (string) – the parameter to read from the file.
-
row_as_array
(row)¶ Return row as a numpy array.
Parameters: row (int) – row to return Returns: Values Return type: np.array
-
row_grad_as_array
(row)¶ Return row gradient as a numpy array.
Parameters: row (int) – row to return Returns: Values Return type: np.array
-
rows_as_array
(rows)¶ Return rows as a numpy array.
The first dimension is the lookup dimension
Parameters: rows (list) – rows to return Returns: Values Return type: np.array
-
rows_grad_as_array
(rows)¶ Return rows gradients as a numpy array.
The first dimension is the lookup dimension
Parameters: rows (list) – rows to return Returns: Values Return type: np.array
-
save
(fname, key='', append=False)¶ Save the values of this LookupParameters object to a particular file.
TODO: more docs. Refer to the tutorial for more info for now
Parameters: - fname (string) – the name of a file to save to.
- key (string) – TODO
-
scale
(s)¶ Scales the parameter
Parameters: s (float) – Scale
-
scale_gradient
(s)¶ Scales the gradient
Parameters: s (float) – Scale
-
shape
()¶ Returns shape of the lookup parameter
The first dimension is the lookup dimension
Returns: Shape of the parameter Return type: tuple
-
zero
()¶ Set all values to zero
-
Parameters initializers¶
-
class
dynet.
PyInitializer
¶ Base class for parameter initializer
-
class
dynet.
NormalInitializer
(mean=0, var=1)¶ Bases:
dynet.PyInitializer
Initialize the parameters with a gaussian distribution
Keyword Arguments: - mean (number) – Mean of the distribution (default: 0)
- var (number) – Variance of the distribution (default: 1)
-
class
dynet.
UniformInitializer
(scale)¶ Bases:
dynet.PyInitializer
Initialize the parameters with a uniform distribution
Parameters: scale (number) – Parameters are sampled from \(\mathcal U([-\texttt{scale},\texttt{scale}])\)
-
class
dynet.
ConstInitializer
(c)¶ Bases:
dynet.PyInitializer
Initialize the parameters with a constant value
Parameters: c (number) – Value to initialize the parameters
-
class
dynet.
IdentityInitializer
¶ Bases:
dynet.PyInitializer
Initialize the parameters as the identity
Only works with square matrices
-
class
dynet.
GlorotInitializer
(is_lookup=False, gain=1.0)¶ Bases:
dynet.PyInitializer
Initializes the weights according to Glorot & Bengio (2011)
If the dimensions of the parameter matrix are \(m,n\), the weights are sampled from \(\mathcal U([-g\sqrt{\frac{6}{m+n}},g\sqrt{\frac{6}{m+n}}])\)
The gain \(g\) depends on the activation function :
- \(\text{tanh}\) : 1.0
- \(\text{ReLU}\) : 0.5
- \(\text{sigmoid}\) : 4.0
- Any smooth function \(f\) : \(\frac{1}{f'(0)}\)
Note: This is also known as Xavier initialization
Keyword Arguments: - is_lookup (bool) – Whether the parameter is alookup parameter (default: False)
- gain (number) – Gain (Depends on the activation function) (default: 1.0)
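A minimal sketch of passing an initializer when creating parameters (the shapes are illustrative):
import dynet as dy
pc = dy.ParameterCollection()
# gain=0.5 is the value suggested above for ReLU activations
W = pc.add_parameters((100, 50), init=dy.GlorotInitializer(gain=0.5))
E = pc.add_lookup_parameters((1000, 50), init=dy.GlorotInitializer(is_lookup=True))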
-
class
dynet.
SaxeInitializer
(scale=1.0)¶ Bases:
dynet.PyInitializer
Initializes according to Saxe et al. (2014)
- Initializes as a random orthonormal matrix (unimplemented for GPU)
- Keyword Arguments:
- scale (number): scale to apply to the orthonormal matrix
-
class
dynet.
FromFileInitializer
(fname)¶ Bases:
dynet.PyInitializer
Initialize parameter from file
Parameters: fname (str) – File name
-
class
dynet.
NumpyInitializer
(array)¶ Bases:
dynet.PyInitializer
Initialize from numpy array
Alternatively, use
ParameterCollection.parameters_from_numpy()
Parameters: array (np.ndarray) – Numpy array
High level saving/loading¶
-
dynet.
save
(basename, lst)¶ Saves a list of parameters, lookup parameters and builder objects to disk.
Parameters: - basename (string) – The base-name of the files to save. Two files will be created: basename.data and basename.meta.
- lst (list) – A list of objects to save (see below).
Example
import dynet as dy
pc = dy.ParameterCollection()
W = pc.add_parameters((100,50))
E = pc.add_lookup_parameters((1000,50))
builder = dy.LSTMBuilder(2, 50, 50, pc)
dy.save("model", [E, builder, W])
# then, when loading:
pc = dy.ParameterCollection()
E2, builder2, W2 = dy.load("model", pc)
- What can be saved:
Each object in lst must be one of the following:
1. Parameter
2. LookupParameter
3. one of the built-in types (CompactVanillaLSTMBuilder, VanillaLSTMBuilder, LSTMBuilder, GRUBuilder, SimpleRNNBuilder, BiRNNBuilder)
4. a type adhering to the following interface:
  - has a param_collection() method returning a ParameterCollection object with the parameters in the object.
  - has a .spec property with picklable items describing the object.
  - has a .from_spec(spec, model) static method that will create and return a new instance with the needed parameters/etc in the model.
Note, the built-in types in (3) above can be saved/loaded this way simply because they support this interface.
Behind the scenes:
- for each item, we write to .meta:
  - if it is a Parameters/ParameterCollection: its type and full name.
  - if it is a builder: its class, its spec, and the full name of its parameter collection.
- the associated parameters/sub-collection is then saved to .data
-
dynet.
load
(basename, params)¶ Loads a list of parameters, lookup parameters and builder objects from disk. The loaded objects are added to the supplied params collection, and returned.
Parameters: - basename (string) – The basename to read from. This is the same string that was used when saving the objects.
- params (dynet.ParameterCollection) – A ParameterCollection to add the loaded objects to.
Returns: A list of parameters, lookup parameters and builder objects, in the same order they were passed to the save function.
Example
import dynet as dy
pc = dy.ParameterCollection()
W = pc.add_parameters((100,50))
E = pc.add_lookup_parameters((1000,50))
builder = dy.LSTMBuilder(2, 50, 50, pc)
dy.save("model", [E, builder, W])
# then, when loading:
pc = dy.ParameterCollection()
E2, builder2, W2 = dy.load("model", pc)
Computation Graph¶
-
dynet.
renew_cg
(immediate_compute=False, check_validity=False, autobatching=None)¶ Renew the computation graph.
Call this before building any new computation graph
-
dynet.
cg_version
()¶ Version of the current computation graph
-
dynet.
print_text_graphviz
()¶
-
dynet.
cg_checkpoint
()¶ Saves the state of the computation graph
-
dynet.
cg_revert
()¶ Revert the computation graph state to the previous checkpoint
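A small sketch of how cg_checkpoint and cg_revert might be used to build and then discard a speculative part of the graph (the dimensions are arbitrary):
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((5, 5))
dy.renew_cg()
h = dy.parameter(W) * dy.inputVector([1, 2, 3, 4, 5])
dy.cg_checkpoint()          # remember the current graph state
tmp = dy.tanh(h)            # speculative extra nodes
dy.cg_revert()              # discard everything added since the checkpoint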
-
dynet.
cg
()¶ Get the current ComputationGraph
-
class
dynet.
ComputationGraph
¶ Computation graph object
While the ComputationGraph is central to the inner workings of DyNet, from the user’s perspective, the only responsibility is to create a new computation graph for each training example.
-
parameters
(params)¶ Same as
dynet.parameters(params)
-
renew
(immediate_compute=False, check_validity=False, autobatching=None)¶ Same as
dynet.renew_cg()
-
version
()¶ Same as
dynet.cg_version()
-
Operations¶
Expressions¶
-
class
dynet.
Expression
¶ Expressions are the building block of a Dynet computation graph.
Expressions are the main data types being manipulated in a DyNet program. Each expression represents a sub-computation in a computation graph.
-
backward
(full=False)¶ Run the backward pass based on this expression
The parameter
full
specifies whether the gradients should be computed for all nodes (True
) or only non-constant nodes (False
).By default, a node is constant unless
- it is a parameter node
- it depends on a non-constant node
Thus, functions of constants and inputs are considered as constants.
Turn
full
on if you want to retrieve gradients w.r.t. inputs for instance. By default this is turned off, so that the backward pass ignores nodes which have no influence on gradients w.r.t. parameters for efficiency.Parameters: full (bool) – Whether to compute all gradients (including with respect to constant nodes).
-
dim
()¶ Dimension of the expression
Returns a tuple (dims,batch_dim) where dims is the tuple of dimensions of each batch element
Returns: dimension Return type: tuple
-
forward
(recalculate=False)¶ This runs incremental forward on the entire graph
May not be optimal in terms of efficiency. Prefer
values
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False)
-
gradient
()¶ Returns the gradient of the expression as a numpy array
The last dimension is the batch size (if it’s > 1).
Make sure to call
backward
on a downstream expression before calling this. If the Expression is a constant expression (meaning it's not a function of a parameter), dynet won't compute its gradient for the sake of efficiency. You need to manually force the gradient computation by adding the argument
full=True
tobackward
Returns: numpy array of values Return type: np.ndarray
-
npvalue
(recalculate=False)¶ Returns the value of the expression as a numpy array
The last dimension is the batch size (if it’s > 1)
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: numpy array of values Return type: np.ndarray
-
scalar_value
(recalculate=False)¶ Returns value of an expression as a scalar
This only works if the expression is a scalar
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: Scalar value of the expression Return type: float
-
tensor_value
(recalculate=False)¶ Returns the value of the expression as a Tensor.
This is useful if you want to use the value for other on-device calculations that are not part of the computation graph, e.g. computing an argmax.
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: a dynet Tensor object. Return type: Tensor
-
value
(recalculate=False)¶ Gets the value of the expression in the most relevant format
this returns the same thing as
scalar_value
,vec_value
,npvalue
depending on whether the number of dimensions of the expression is 0, 1 or 2+Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: Value of the expression Return type: float, list, np.ndarray
-
vec_value
(recalculate=False)¶ Returns the value of the expression as a vector
In case of a multidimensional expression, the values are flattened according to a column major ordering
Keyword Arguments: recalculate (bool) – Recalculate the computation graph (for static graphs with new inputs) (default: False) Returns: Array of values Return type: list
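The following is a small, illustrative sketch of the relationship between value retrieval, backward and gradient (the numbers are arbitrary):
import dynet as dy

dy.renew_cg()
x = dy.inputVector([1.0, 2.0, 3.0])
y = dy.squared_norm(x)          # y = 1 + 4 + 9 = 14
print(y.scalar_value())         # 14.0
y.backward(full=True)           # full=True so gradients w.r.t. inputs are kept
print(x.gradient())             # [2. 4. 6.], i.e. dy/dx = 2x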
-
Operations¶
Operations are used to build expressions
Input operations¶
-
dynet.
inputTensor
(arr, batched=False, device='')¶ Creates a tensor expression based on a numpy array or a list.
The dimension is inferred from the shape of the input. If batched=True, the last dimension is used as a batch dimension. If arr is a list of numpy ndarrays, this returns a batched expression where the batch elements are the elements of the list.
Parameters: arr (list,np.ndarray) – Values : numpy ndarray OR list of np.ndarray OR multidimensional list of floats
Keyword Arguments: - batched (bool) – Whether to use the last dimension as a batch dimension (default: False)
- device (string) – Optional, device on which to create the expression.
Returns: Input expression
Return type: _vecInputExpression
Raises: TypeError
– If the type is not respected
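A brief sketch of the batched and non-batched behaviour (the shapes are arbitrary):
import dynet as dy
import numpy as np

dy.renew_cg()
x = dy.inputTensor(np.zeros((2, 3)))                   # a 2x3 matrix expression
print(x.dim())                                         # ((2, 3), 1)
xb = dy.inputTensor(np.zeros((2, 3, 4)), batched=True)
print(xb.dim())                                        # ((2, 3), 4): last axis is the batch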
-
dynet.
sparse_inputTensor
(idxs, values, shape, batched=False, defval=0, device='')¶ Creates a tensor expression based on indices and values
The dimensions are given by the shape argument. If batched=True, the last dimension of the shape is used as a batch dimension.
Parameters: - idxs (tuple, list) – A tuple/list of integer arrays, one array for each dimension (including the batch dimension)
- values (list,np.ndarray) – A 1D array/list of values
- shape – The desired shape
Keyword Arguments: - batched (bool) – Whether to use the last dimension as a batch dimension (default: False). For example if
shape=(3, 3, 3)
andbatched=True
the resulting expression will be a batch of 3 3x3 matrices - defval (number) – The default value for all unspecified coordinates (default: 0)
- device (string) – Optional, device on which to create the expression.
Returns: Input expression
Return type: _vecInputExpression
Raises: TypeError
– If the type is not respectedValueError
– If the number of dimensions don’t match
-
dynet.
parameter
(*args)¶ Add parameters to the computation graph.
Get the expression objects corresponding to parameters. Gradients for parameters will be computed and used by Optimizers to update.
Parameters: args – Parameter and LookupParameter objects to add to the computation graph.
Returns: one expression for each input parameter.
Return type: Raises: NotImplementedError
– Only works with Parameters and LookupParameters.
-
dynet.
const_parameter
(*args)¶ Add constant parameters to the computation graph.
Get the expression objects corresponding to parameters. Gradients for these parameters will NOT be computed or used by Optimizers for updates. To access parameters that should be updated (which is usually what you want), use parameter() instead.
Parameters: args – Parameter and LookupParameter objects to add to the computation graph.
Returns: one expression for each input parameter.
Return type: Raises: NotImplementedError
– Only works with Parameters and LookupParameters.
-
dynet.
scalarInput
(s, device='')¶
-
dynet.
vecInput
(dim, device='')¶ Input an empty vector
Parameters: - dim (number) – Size
- device (string) – Optional, device on which to create the expression.
Returns: Corresponding expression
Return type: _vecInputExpression
-
dynet.
inputVector
(v, device='')¶ Input a vector by values
Parameters: - v (vector[float]) – Values
- device (string) – Optional, device on which to create the expression.
Returns: Corresponding expression
Return type: _vecInputExpression
-
dynet.
matInput
(d1, d2)¶ DEPRECATED : use inputTensor
TODO : remove this
Parameters: - d1 (int) – [description]
- d2 (int) – [description]
Returns: [description]
Return type:
-
dynet.
inputMatrix
(v, d)¶ DEPRECATED : use inputTensor
TODO : remove this
inputMatrix(vector[float] v, tuple d)
Create a matrix literal. First argument is a list of floats (or a flat numpy array). Second argument is a dimension. Returns: an expression. Usage example:
x = inputMatrix([1,2,3,4,5,6],(2,3)) x.npvalue() --> array([[ 1., 3., 5.], [ 2., 4., 6.]])
-
dynet.
lookup
(p, index=0, update=True)¶ Pick an embedding from a lookup parameter and return it as an expression
param p: Lookup parameter to pick from type p: LookupParameters Keyword Arguments: - index (number) – Lookup index (default: 0)
- update (bool) – Whether to update the lookup parameter (default: True)
Returns: Expression for the embedding
Return type: _lookupExpression
-
dynet.
lookup_batch
(p, indices, update=True)¶ Look up parameters.
The mini-batched version of lookup. The resulting expression will be a mini-batch of parameters, where the “i”th element of the batch corresponds to the parameters at the position specified by the “i”th element of “indices”
Parameters: - p (LookupParameters) – Lookup parameter to pick from
- indices (list(int)) – Indices to look up for each batch element
Keyword Arguments: update (bool) – Whether to update the lookup parameter (default: True)
Returns: Expression for the batched embeddings
Return type: _lookupBatchExpression
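A short sketch contrasting lookup and lookup_batch (the vocabulary size and indices are arbitrary):
import dynet as dy

pc = dy.ParameterCollection()
E = pc.add_lookup_parameters((1000, 50))   # 1000 vectors of dimension 50
dy.renew_cg()
e3 = dy.lookup(E, 3)                       # embedding of entry 3
eb = dy.lookup_batch(E, [3, 7, 42])        # a batch of three embeddings
print(e3.dim())                            # ((50,), 1)
print(eb.dim())                            # ((50,), 3)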
-
dynet.
zeros
(dim, batch_size=1)¶ Create an input full of zeros
Create an input full of zeros, sized according to dimensions
dim
Parameters: dim (tuple, int) – Dimension of the tensor Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1)) Returns: A d dimensioned zero tensor Return type: dynet.Expression
-
dynet.
ones
(dim, batch_size=1)¶ Create an input full of ones
Create an input full of ones, sized according to dimensions
dim
Parameters: dim (tuple, int) – Dimension of the tensor Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1)) Returns: A d dimensioned tensor of ones Return type: dynet.Expression
-
dynet.
constant
(dim, val, batch_size=1)¶ Create an input full of
val
Create an input full of
val
, sized according to dimensions dim
Parameters: - dim (tuple, int) – Dimension of the tensor
- val (number) – Value
Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1))
Returns: A d dimensioned tensor filled with value val
Return type:
-
dynet.
random_normal
(dim, batch_size=1)¶ Create a random normal vector
Create a vector distributed according to normal distribution with mean 0, variance 1.
Parameters: dim (tuple, int) – Dimension of the tensor Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1)) Returns: A “d” dimensioned normally distributed tensor Return type: dynet.Expression
-
dynet.
random_bernoulli
(dim, p, scale=1.0, batch_size=1)¶ Create a random bernoulli tensor
Create a tensor distributed according to bernoulli distribution with parameter \(p\).
Parameters: - dim (tuple, int) – Dimension of the tensor
- p (number) – Parameter of the bernoulli distribution
Keyword Arguments: - scale (number) – Scaling factor to apply to the sampled tensor (default: (1.0))
- batch_size (number) – Batch size of the tensor (default: (1))
Returns: A “d” dimensioned bernoulli distributed tensor
Return type:
-
dynet.
random_uniform
(dim, left, right, batch_size=1)¶ Create a random uniform tensor
Create a tensor distributed according to uniform distribution with boundaries left and right.
Parameters: - dim (tuple, int) – Dimension of the tensor
- left (number) – Lower bound of the uniform distribution
- right (number) – Upper bound of the uniform distribution
Keyword Arguments: batch_size (number) – Batch size of the tensor (default: (1))
Returns: A “d” dimensioned uniform distributed tensor
Return type:
-
dynet.
random_gumbel
(dim, mu=0.0, beta=1.0, batch_size=1)¶ Create a random Gumbel sampled vector
Create a vector distributed according to a Gumbel distribution with the specified parameters. (Currently only the defaults of mu=0.0 and beta=1.0 are supported.)
Parameters: dim (tuple, int) – Dimension of the tensor
Keyword Arguments: - mu (number) – The \(\mu\) parameter (default: (0.0))
- beta (number) – The \(\beta\) parameter (default: (1.0))
- batch_size (number) – Batch size of the tensor (default: (1))
Returns: “d” dimensioned Gumbel distributed tensor
Return type:
-
dynet.
noise
(x, stddev)¶ Additive gaussian noise
Add gaussian noise to an expression.
Parameters: - x (dynet.Expression) – Input expression
- stddev (number) – The standard deviation of the gaussian
Returns: \(y\sim\mathcal N(x,\texttt{stddev})\)
Return type:
Arithmetic operations¶
-
dynet.
cdiv
(x, y)¶ Componentwise division
- Divide an expression component-wise by another, broadcasting dimensions (currently only of the second expression!) if necessary as follows:
- When the numbers of dimensions differ, dimensions of size 1 are added to make them match
- Then, every dimension must have matching sizes, or the size of the corresponding dimension of the right expression must equal 1 (in which case it will be broadcast)
- In the same way, the batch sizes must match, or the batch size of the right expression must equal 1, in which case it will be broadcast
- The resulting tensor’s dimensionality is thus determined as the max of both inputs at every position
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: An expression where the ith element is equal to \(\frac{x_i}{y_i}\)
Return type:
-
dynet.
cmult
(x, y)¶ Componentwise multiplication
- Multiply two expressions component-wise, broadcasting dimensions if necessary as follows:
- When the numbers of dimensions differ, dimensions of size 1 are added to make them match
- Then, every dimension must have matching sizes, or one of the dimensions must equal 1 (in which case it will be broadcast)
- In the same way, the batch dimensions must match, or one of them must equal 1, in which case it will be broadcast
- The resulting tensor’s dimensionality is thus determined as the max of both inputs at every position
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: An expression where the ith element is equal to \(x_i\times y_i\)
Return type:
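A small sketch of the broadcasting behaviour described above, multiplying and dividing a 2x2 matrix by a 2x1 column vector (the values are arbitrary):
import dynet as dy

dy.renew_cg()
x = dy.inputTensor([[1.0, 2.0], [3.0, 4.0]])   # 2x2 matrix
v = dy.inputTensor([[10.0], [20.0]])           # 2x1 column, broadcast over columns
print(dy.cmult(x, v).npvalue())                # [[10. 20.], [60. 80.]]
print(dy.cdiv(x, v).npvalue())                 # [[0.1 0.2], [0.15 0.2]]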
-
dynet.
colwise_add
(x, y)¶ Columnwise addition
Add vector \(y\) to each column of matrix \(x\)
Parameters: - x (dynet.Expression) – An MxN matrix
- y (dynet.Expression) – A length M vector
Returns: An expression where \(y\) is added to each column of \(x\)
Return type:
-
dynet.
squared_norm
(x)¶ Squared norm
The squared norm of the values of
x
: \(\Vert x\Vert_2^2=\sum_i x_i^2\).Parameters: x (dynet.Expression) – Input expression Returns: \(\Vert x\Vert_2^2=\sum_i x_i^2\) Return type: dynet.Expression
-
dynet.
l2_norm
(x)¶ L2 norm
The l2 norm of the values of
x
: \(\Vert x\Vert_2=\sqrt{\sum_i x_i^2}\).Parameters: x (dynet.Expression) – Input expression Returns: \(\Vert x\Vert_2=\sqrt{\sum_i x_i^2}\) Return type: dynet.Expression
-
dynet.
exp
(x)¶ Natural exponent
Calculate elementwise \(y_i = e^{x_i}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(e^{x}\) Return type: dynet.Expression
-
dynet.
square
(x)¶ Square
Calculate elementwise \(y_i = x_i^2\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = x^2\) Return type: dynet.Expression
-
dynet.
sqrt
(x)¶ Square root
Calculate elementwise \(y_i = \sqrt{x_i}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = \sqrt{x}\) Return type: dynet.Expression
-
dynet.
abs
(x)¶ Absolute value
Calculate elementwise \(y_i = \vert x_i\vert\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = \vert x\vert\) Return type: dynet.Expression
-
dynet.
erf
(x)¶ Gaussian error function
Elementwise calculation of the Gaussian error function \(y_i = \text{erf}(x_i)=\frac {1}{\sqrt{\pi}}\int_{-x_i}^{x_i}e^{-t^2}\mathrm{d}t\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \text{erf}(x_i)\) Return type: dynet.Expression
-
dynet.
cube
(x)¶ Calculate elementwise \(y_i = x_i^3\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y = x^3\) Return type: dynet.Expression
-
dynet.
log
(x)¶ Natural logarithm
Elementwise calculation of the natural logarithm \(y_i = \ln(x_i)\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \ln(x_i)\) Return type: dynet.Expression
-
dynet.
lgamma
(x)¶ Log gamma
Calculate elementwise log gamma function \(y_i = \ln(\Gamma(x_i))\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \ln(\Gamma(x_i))\) Return type: dynet.Expression
-
dynet.
sin
(x)¶ Sine
Elementwise calculation of the sine
Parameters: x (dynet.Expression) – Input expression Returns: \(\sin(x)\) Return type: dynet.Expression
-
dynet.
cos
(x)¶ Cosine
Elementwise calculation of the cosine
Parameters: x (dynet.Expression) – Input expression Returns: \(\cos(x)\) Return type: dynet.Expression
-
dynet.
tan
(x)¶ Tangent
Elementwise calculation of the tangent
Parameters: x (dynet.Expression) – Input expression Returns: \(\tan(x)\) Return type: dynet.Expression
-
dynet.
asin
(x)¶ Inverse sine
Elementwise calculation of the inverse sine
Parameters: x (dynet.Expression) – Input expression Returns: \(\sin^{-1}(x)\) Return type: dynet.Expression
-
dynet.
acos
(x)¶ Inverse cosine
Elementwise calculation of the inverse cosine
Parameters: x (dynet.Expression) – Input expression Returns: \(\cos^{-1}(x)\) Return type: dynet.Expression
-
dynet.
atan
(x)¶ Inverse tangent
Elementwise calculation of the inverse tangent
Parameters: x (dynet.Expression) – Input expression Returns: \(\tan^{-1}(x)\) Return type: dynet.Expression
-
dynet.
sinh
(x)¶ Hyperbolic sine
Elementwise calculation of the hyperbolic sine
Parameters: x (dynet.Expression) – Input expression Returns: \(\sinh(x)\) Return type: dynet.Expression
-
dynet.
cosh
(x)¶ Hyperbolic cosine
Elementwise calculation of the hyperbolic cosine
Parameters: x (dynet.Expression) – Input expression Returns: \(\cosh(x)\) Return type: dynet.Expression
-
dynet.
tanh
(x)¶ Hyperbolic tangent
Elementwise calculation of the hyperbolic tangent
Parameters: x (dynet.Expression) – Input expression Returns: \(\tanh(x)\) Return type: dynet.Expression
-
dynet.
asinh
(x)¶ Inverse hyperbolic sine
Elementwise calculation of the inverse hyperbolic sine
Parameters: x (dynet.Expression) – Input expression Returns: \(\sinh^{-1}(x)\) Return type: dynet.Expression
-
dynet.
acosh
(x)¶ Inverse hyperbolic cosine
Elementwise calculation of the inverse hyperbolic cosine
Parameters: x (dynet.Expression) – Input expression Returns: \(\cosh^{-1}(x)\) Return type: dynet.Expression
-
dynet.
atanh
(x)¶ Inverse hyperbolic tangent
Elementwise calculation of the inverse hyperbolic tangent
Parameters: x (dynet.Expression) – Input expression Returns: \(\tanh^{-1}(x)\) Return type: dynet.Expression
-
dynet.
logistic
(x)¶ Logistic sigmoid function
Calculate elementwise \(y_i = \frac{1}{1+e^{-x_i}}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \frac{1}{1+e^{-x_i}}\) Return type: dynet.Expression
-
dynet.
rectify
(x)¶ Rectifier (or ReLU, Rectified Linear Unit)
Calculate elementwise rectifier (ReLU) function \(y_i = \max(x_i,0)\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \max(x_i,0)\) Return type: dynet.Expression
-
dynet.
elu
(x, alpha=1.0)¶ Exponential Linear Unit (ELU)
Calculate elementwise the function
\[\begin{split}y_i = \left\{\begin{array}{lr} x_i, & \text{if } x>0\\ \alpha\times(e^{x_i} - 1), & \text{if }x\leqslant 0 \end{array}\right.\end{split}\]Reference: Clevert et al., 2015
Parameters: - x (dynet.Expression) – Input expression
- alpha (number) – \(\alpha\) parameter
Returns: \(\text{ELU}(x_i, \alpha)\)
Return type:
-
dynet.
selu
(x)¶ Scaled Exponential Linear Unit (SELU)
Calculate elementwise the function
\[\begin{split}y_i = \lambda\times\left\{ \begin{array}{lr} x_i, & \text{if } x>0\\ \alpha\times(e^{x_i} - 1), & \text{if }x\leqslant 0\\ \end{array}\right.\end{split}\]With
\[\begin{split}\begin{split} \lambda &=\texttt{1.0507009873554804934193349852946}\\ \alpha &=\texttt{1.6732632423543772848170429916717}\\ \end{split}\end{split}\]Reference: Klambauer et al., 2017
Parameters: x (dynet.Expression) – Input expression Returns: \(\text{SELU}(x_i)\) Return type: dynet.Expression
-
dynet.
sparsemax
(x)¶ Sparsemax
The sparsemax function (Martins et al. 2016), which is similar to softmax, but induces sparse solutions where most of the vector elements are zero. Note: This function is not yet implemented on GPU.
Parameters: x (dynet.Expression) – Input expression Returns: The sparsemax of the scores Return type: dynet.Expression
-
dynet.
softsign
(x)¶ Softsign function
Calculate elementwise the softsign function \(y_i = \frac{x_i}{1+\vert x_i\vert}\)
Parameters: x (dynet.Expression) – Input expression Returns: \(y_i = \frac{x_i}{1+\vert x_i\vert}\) Return type: dynet.Expression
-
dynet.
pow
(x, y)¶ Power function
Calculate an output where the ith element is equal to \(x_i^{y_i}\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(x_i^{y_i}\)
Return type:
-
dynet.
bmin
(x, y)¶ Minimum
Calculate an output where the ith element is \(\min(x_i,y_i)\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\min(x_i,y_i)\)
Return type:
-
dynet.
bmax
(x, y)¶ Maximum
Calculate an output where the ith element is \(\max(x_i,y_i)\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\max(x_i,y_i)\)
Return type:
Reduction/moment operations¶
-
dynet.
sum_elems
(x)¶ Sum all elements
Sum all the elements in an expression.
Parameters: x (dynet.Expression) – Input expression Returns: The sum of all of its elements Return type: dynet.Expression
-
dynet.
moment_elems
(x, r)¶ Statistical moment of elements of the tensor
Computes the statistical moment of order \(r\), \(\frac 1 n \sum_ix_i^r\), of all the elements of each minibatch.
Parameters: - x (dynet.Expression) – Input expression
- r (int) – Moment order
Returns: A scalar expression (minibatched) Return type: dynet.Expression
-
dynet.
mean_elems
(x)¶ Mean of elements of the tensor
Computes the mean \(\frac 1 n \sum_ix_i\) of all the elements of each minibatch.
Parameters: x (dynet.Expression) – Input expression
Returns: A scalar expression (minibatched) Return type: dynet.Expression
-
dynet.
std_elems
(x)¶ Standard deviation of elements of the tensor
Computes the standard deviation \(\sigma=\sqrt{\frac 1 n \sum_i(x_i-\mu)^2}\) of all the elements of each minibatch.
Parameters: x (dynet.Expression) – Input expression
Returns: A scalar expression (minibatched) Return type: dynet.Expression
-
dynet.
sum_dim
(x, d, b=False, n=0)¶ Sum along an arbitrary dimension
Computes the sum \(\sum_ix_i\) along an arbitrary dimension or dimensions.
Parameters: - x (dynet.Expression) – Input expression
- d (list) – Dimensions along which to reduce
- b (bool) – Whether to include batch dimension
Returns: An expression with |d| fewer dimensions and a possibly dropped batch dimension
Return type:
-
dynet.
moment_dim
(x, d, r, b, n=0)¶ Statistical moment along an arbitrary dimension
Computes the statistical moment of order \(r\), \(\frac 1 n \sum_ix_i^r\) along an arbitrary dimension.
Parameters: - x (dynet.Expression) – Input expression
- d (list) – Dimensions along which to reduce
- r (int) – Moment order
- b (bool) – Whether to include batch dimension
- n (int) – If > 0, overwrite the n in the equation by this value, useful for masking
Returns: An expression with |d| fewer dimensions and a possibly dropped batch dimension
Return type:
-
dynet.
mean_dim
(x, d, b, n=0)¶ Mean along an arbitrary dimension
Computes the mean \(\frac 1 n \sum_ix_i\) along an arbitrary dimension.
Parameters: - x (dynet.Expression) – Input expression
- d (list) – Dimensions along which to reduce
- b (bool) – Whether to include batch dimension
- n (int) – If > 0, overwrite the n in the equation by this value, useful for masking
Returns: An expression with |d| fewer dimensions and a possibly dropped batch dimension
Return type:
-
dynet.
std_dim
(x, d, b, n=0)¶ Standard deviation along an arbitrary dimension
Computes the standard deviation \(\sigma=\sqrt{\frac 1 n \sum_i(x_i-\mu)^2}\) along arbitrary dimensions.
Parameters: - x (dynet.Expression) – Input expression
- d (int) – Dimensions along which to reduce
- b (bool) – Whether to include batch dimension
- n (int) – If > 0, overwrite the n in the equation by this value, useful for masking
Returns: An expression with |d| fewer dimensions and a possibly dropped batch dimension
Return type:
-
dynet.
max_dim
(x, d=0)¶ Max out through a dimension
Select out an element/row/column/sub-tensor from an expression, with maximum value along a given dimension. This will result in the dimension of the expression being reduced by 1.
Parameters: x (dynet.Expression) – Input expression Keyword Arguments: d (int) – Dimension on which to perform the maxout (default: (0)) Returns: An expression of sub-tensor with max value along dimension d
Return type: dynet.Expression
-
dynet.
min_dim
(x, d=0)¶ Min out through a dimension
Select out an element/row/column/sub-tensor from an expression, with minimum value along a given dimension. This will result in the dimension of the expression being reduced by 1.
Parameters: x (dynet.Expression) – Input expression Keyword Arguments: d (int) – Dimension on which to perform the minout (default: (0)) Returns: An expression of sub-tensor with min value along dimension d
Return type: dynet.Expression
-
dynet.
sum_batches
(x)¶ Sum over minibatches
Sum an expression that consists of multiple minibatches into one of equal dimension but with only a single minibatch. This is useful for summing loss functions at the end of minibatch training.
Parameters: x (dynet.Expression) – Input expression Returns: An expression with a single batch Return type: dynet.Expression
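A minimal sketch of the minibatch-loss pattern mentioned above, combining a batched input, pickneglogsoftmax_batch (documented further below) and sum_batches (the dimensions and labels are made up):
import dynet as dy
import numpy as np

pc = dy.ParameterCollection()
W = pc.add_parameters((5, 10))
trainer = dy.SimpleSGDTrainer(pc)

dy.renew_cg()
x = dy.inputTensor(np.random.rand(10, 4), batched=True)   # batch of 4 inputs
scores = dy.parameter(W) * x                               # batched class scores
losses = dy.pickneglogsoftmax_batch(scores, [0, 2, 1, 4])  # one loss per batch element
loss = dy.sum_batches(losses)                              # single scalar loss
loss.backward()
trainer.update()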
-
dynet.
moment_batches
(x, r)¶ Statistical moment along the batch dimension
Computes the statistical moment of order \(r\), \(\frac 1 n \sum_ix_i^r\), along the batch dimension.
Parameters: - x (dynet.Expression) – Input expression
- r (int) – Moment order
Returns: An expression with a single batch Return type: dynet.Expression
-
dynet.
mean_batches
(x)¶ Mean along the batch dimension
Computes the mean \(\frac 1 n \sum_ix_i\) along the batch dimension.
Parameters: x (dynet.Expression) – Input expression
Returns: An expression with a single batch Return type: dynet.Expression
-
dynet.
std_batches
(x)¶ Standard deviation along the batch dimension
Computes the standard deviation \(\sigma=\sqrt{\frac 1 n \sum_i(x_i-\mu)^2}\) along the batch dimension.
Parameters: x (dynet.Expression) – Input expression
Returns: An expression with a single batch Return type: dynet.Expression
-
dynet.
fold_rows
(x, nrows=2)¶ [summary]
[description]
Parameters: x (dynet.Expression) – Keyword Arguments: nrows {number} (unsigned) – (default: (2)) Returns: Return type: dynet.Expression
-
dynet.
esum
(xs)¶ Sum
This performs an elementwise sum over all the expressions in
xs
Parameters: xs (list) – A list of expression of same dimension Returns: An expression where the ith element is equal to \(\sum_{j=0}\texttt{xs[}j\texttt{][}i\texttt{]}\) Return type: dynet.Expression
-
dynet.
emax
(xs)¶ Max
This performs an elementwise max over all the expressions in
xs
Parameters: xs (list) – A list of expression of same dimension Returns: An expression where the ith element is equal to \(\max_j\texttt{xs[}j\texttt{][}i\texttt{]}\) Return type: dynet.Expression
-
dynet.
logsumexp
(xs)¶ Log, sum, exp
The elementwise “logsumexp” function that calculates \(\ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.
Parameters: xs (list) – A list of expression of same dimension Returns: An expression where the ith element is equal to \(\ln\left(\sum_{j=0}e^{\texttt{xs[}j\texttt{][}i\texttt{]}}\right)\) Return type: dynet.Expression
-
dynet.
average
(xs)¶ Average
This performs an elementwise average over all the expressions in
xs
Parameters: xs (list) – A list of expression of same dimension Returns: An expression where the ith element is equal to \(\frac{1}{\texttt{len(xs)}}\sum_{j=0}\texttt{xs[}j\texttt{][}i\texttt{]}\) Return type: dynet.Expression
Loss/Probability operations¶
-
dynet.
softmax
(x, d=0)¶ Softmax
The softmax function normalizes each column to ensure that all values are between 0 and 1 and add up to one by applying \(\frac{e^{x_i}}{\sum_j e^{x_j}}\).
Parameters: - x (dynet.Expression) – Input expression
- d (int) – Dimension to normalize over
Returns: \(\frac{e^{x_i}}{\sum_j e^{x_j}}\)
Return type:
-
dynet.
log_softmax
(x, restrict=None)¶ Restricted log softmax
The log softmax function calculated over only a subset of the vector elements. The elements to be included are set by the
restrict
argument. All elements not included in restrict
are set to negative infinity. Parameters: x (dynet.Expression) – Input expression Keyword Arguments: restrict (list) – List of indices over which to compute the log softmax (default: (None)) Returns: A vector with the log softmax over the specified elements Return type: dynet.Expression
-
dynet.
pairwise_rank_loss
(x, y, m=1.0)¶ Pairwise rank loss
A margin-based loss, where every margin violation for each pair of values is penalized: \(\sum_i \max(m - x_i + y_i, 0)\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Keyword Arguments: m (number) – The margin (default: (1.0))
Returns: The pairwise rank loss
Return type:
-
dynet.
poisson_loss
(log_lambda, x)¶ Poisson loss
The negative log probability of
x
according to a Poisson distribution with parameter \(\exp\)log_lambda
. Useful in Poisson regression, where we try to predict the parameters of a Poisson distribution to maximize the probability of the data x
.Parameters: - log_lambda (dynet.Expression) – The log of the Poisson distribution’s lambda
- x (int) – The target value
Returns: The Poisson loss
Return type:
-
dynet.
huber_distance
(x, y, c=1.345)¶ Huber distance
The huber distance between values of
x
andy
parameterized byc
, \(\sum_i L_c(x_i, y_i)\) where:\[\begin{split}L_c(x, y) = \begin{cases} \frac{1}{2}(y - x)^2 & \textrm{for } \vert y - x\vert \le c, \\ c\, \vert y - x\vert - \frac{1}{2}c^2 & \textrm{otherwise.} \end{cases}\end{split}\]Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Keyword Arguments: c (number) – The parameter of the huber distance parameterizing the cutoff (default: (1.345))
Returns: The huber distance
Return type:
-
dynet.
pickneglogsoftmax
(x, v)¶ Negative softmax log likelihood
This function takes in a vector of scores
x
, and performs a log softmax, takes the negative, and selects the likelihood corresponding to the elementv
. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.Parameters: - x (dynet.Expression) – Input scores
- v (int) – True class
Returns: \(-\log\left(\frac{e^{x_v}}{\sum_j e^{x_j}}\right)\)
Return type:
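A short sketch of the usual classification-loss usage (the weight shape and the gold class are arbitrary):
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((3, 4))               # 3 classes, 4 input features
trainer = dy.SimpleSGDTrainer(pc)

dy.renew_cg()
scores = dy.parameter(W) * dy.inputVector([0.1, 0.2, 0.3, 0.4])
loss = dy.pickneglogsoftmax(scores, 2)      # negative log-likelihood of class 2
loss.backward()
trainer.update()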
-
dynet.
pickneglogsoftmax_batch
(x, vs)¶ Negative softmax log likelihood on a batch
This function takes in a batched vector of scores
x
, and performs a log softmax, takes the negative, and selects the likelihood corresponding to the elementsvs
. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.Parameters: - x (dynet.Expression) – Input scores
- v (list) – True classes
Returns: \(-\sum_{v\in \texttt{vs}}\log\left(\frac{e^{x_v}}{\sum_j e^{x_j}}\right)\)
Return type:
-
dynet.
hinge
(x, v, m=1.0)¶ Hinge loss
This function takes in a vector of scores
x
, and calculates a hinge loss such that the elementv
must be greater than all other elements by at leastm
, otherwise a loss is incurred.Parameters: - x (dynet.Expression) – Input scores
- v (int) – True class
- m (float) – The margin
Returns: \(\sum_{\tilde{v} \neq v} \max(x_{\tilde{v}} - x_v + m, 0)\)
Return type:
-
dynet.
hinge_batch
(x, vs, m=1.0)¶ Hinge loss on a batch
This function takes in a batched vector of scores
xs
, and calculates a hinge loss such that the elementsvs
must be greater than all other elements by at leastm
, otherwise a loss is incurred.Parameters: - x (dynet.Expression) – Input scores
- v (list) – True classes
- m (float) – The margin
Returns: The batched hinge loss function
Return type:
-
dynet.
kmh_ngram
(x, v)¶ [summary]
[description]
Parameters: - x (dynet.Expression) –
- v (dynet.Expression) –
Returns: Return type:
-
dynet.
squared_distance
(x, y)¶ Squared distance
The squared distance between values of
x
andy
: \(\Vert x-y\Vert_2^2=\sum_i (x_i-y_i)^2\).Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\Vert x-y\Vert_2^2=\sum_i (x_i-y_i)^2\)
Return type:
-
dynet.
l1_distance
(x, y)¶ L1 distance
L1 distance between values of
x
andy
: \(\Vert x-y\Vert_1=\sum_i \vert x_i-y_i\vert\).Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(\Vert x-y\Vert_1=\sum_i \vert x_i-y_i\vert\).
Return type:
-
dynet.
binary_log_loss
(x, y)¶ Binary log loss
The log loss of a binary decision according to the sigmoid function \(- \sum_i (y_i \ln(x_i) + (1-y_i) \ln(1-x_i))\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(- \sum_i (y_i \ln(x_i) + (1-y_i) \ln(1-x_i))\)
Return type:
Flow/Shaping operations¶
-
dynet.
pick
(e, index=0, dim=0)¶ Pick element.
Pick a single element/row/column/sub-tensor from an expression. This will result in the dimension of the tensor being reduced by 1.
Parameters: e (Expression) – Expression to pick from
Keyword Arguments: - index (number) – Index to pick (default: 0)
- dim (number) – Dimension to pick from (default: 0)
Returns: Picked expression
Return type: _pickerExpression
-
dynet.
pick_batch
(e, indices, dim=0)¶ Batched pick.
Pick elements from multiple batches.
Parameters: - e (Expression) – Expression to pick from
- indices (list) – Indices to pick
- dim (number) – Dimension to pick from (default: 0)
Returns: Picked expression
Return type: _pickerBatchExpression
-
dynet.
pickrange
(x, s, e)¶
-
dynet.
pick_batch_elem
(x, v)¶ Pick batch element.
Pick batch element from a batched expression. For a Tensor with 3 batch elements:
\[\begin{split}\begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix}\end{split}\]pick_batch_elem(t, 1)
will return a Tensor of\[\begin{split}\begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\end{split}\]Parameters: - x (dynet.Expression) – Input expression
- v (int) – The index of the batch element to be picked.
Returns: The expression of the picked batch element. The picked element is a tensor whose batch dimension equals one.
Return type:
-
dynet.
pick_batch_elems
(x, vs)¶ Pick batch element.
Pick batch element from a batched expression. For a Tensor with 3 batch elements:
\[\begin{split}\begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix}\end{split}\]pick_batch_elems(t, [2, 3])
will return a Tensor of\[\begin{split}\begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix}\\ \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix}\end{split}\]Parameters: - x (dynet.Expression) – Input expression
- vs (list) – A list of indices of the batch elements to be picked.
Returns: The expression of the picked batch elements. The picked elements form a tensor whose batch dimension equals the size of the list vs.
Return type:
-
dynet.
reshape
(x, d, batch_size=1)¶ Reshape to another size
This node reshapes a tensor to another size, without changing the underlying layout of the data. The layout of the data in DyNet is column-major, so if we have a 3x4 matrix :
\[\begin{split}\begin{pmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} \\ \end{pmatrix}\end{split}\]and transform it into a 2x6 matrix, it will be rearranged as:
\[\begin{split}\begin{pmatrix} x_{1,1} & x_{3,1} & x_{2,2} & x_{1,3} & x_{3,3} & x_{2,4} \\ x_{2,1} & x_{1,2} & x_{3,2} & x_{2,3} & x_{1,4} & x_{3,4} \\ \end{pmatrix}\end{split}\]Note: This is O(1) for forward, and O(n) for backward.
Parameters: - x (dynet.Expression) – Input expression
- d (tuple) – New dimension
Keyword Arguments: batch_size (int) – New batch size (default: (1))
Returns: The reshaped expression
Return type:
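A small sketch of the column-major behaviour described above (the values are arbitrary):
import dynet as dy
import numpy as np

dy.renew_cg()
x = dy.inputTensor(np.arange(12.0).reshape(3, 4))   # 3x4 matrix
y = dy.reshape(x, (2, 6))                           # same data, rearranged column-major
print(y.dim())                                      # ((2, 6), 1)
print(y.npvalue())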
-
dynet.
select_rows
(x, rs)¶ Select rows
Select a subset of rows of a matrix.
Parameters: - x (dynet.Expression) – Input expression
- rs (list) – The rows to extract
Returns: An expression containing the selected rows
Return type:
-
dynet.
select_cols
(x, cs)¶ Select columns
Select a subset of columns of a matrix.
Parameters: - x (dynet.Expression) – Input expression
- cs (list) – The columns to extract
Returns: An expression containing the selected columns
Return type:
-
dynet.
concatenate_cols
(xs)¶ Concatenate columns
Perform a concatenation of the columns in multiple expressions. All expressions must have the same number of rows.
Parameters: xs (list) – A list of expressions Returns: The expression with the columns concatenated Return type: dynet.Expression
-
dynet.
concatenate
(xs, d=0)¶ Concatenate
Perform a concatenation of multiple expressions along a particular dimension. All expressions must have the same dimensions except for the dimension to be concatenated (rows by default).Parameters: - xs (list) – A list of expressions
- d – The dimension along which to perform the concatenation
Returns: The expression concatenated along the particular dimension
Return type:
-
dynet.
concatenate_to_batch
(xs)¶ Concatenate list of expressions to a single batched expression
Perform a concatenation of several expressions along the batch dimension. All expressions must have the same shape except for the batch dimension.
Parameters: xs (list) – A list of expressions of same dimension (except batch size) Returns: The expression with the batch dimensions concatenated Return type: dynet.Expression
-
dynet.
nobackprop
(x)¶ Prevent backprop
This node has no effect on the forward pass, but prevents gradients from flowing backward during the backward pass. This is useful when there’s a subgraph for which you don’t want loss passed back to the parameters.
Parameters: x (dynet.Expression) – Input expression Returns: An output expression containing the same as input (only effects on backprop process) Return type: dynet.Expression
-
dynet.
flip_gradient
(x)¶ Negative backprop
This node has no effect on the forward pass, but negates the gradient during the backward pass. This operation is widely used in adversarial networks.
Parameters: x (dynet.Expression) – Input expression Returns: An output expression containing the same as input (only effects on backprop process) Return type: dynet.Expression
Noise operations¶
-
dynet.
noise
(x, stddev) Additive gaussian noise
Add gaussian noise to an expression.
Parameters: - x (dynet.Expression) – Input expression
- stddev (number) – The standard deviation of the gaussian
Returns: \(y\sim\mathcal N(x,\texttt{stddev})\)
Return type:
-
dynet.
dropout
(x, p)¶ Dropout
With a fixed probability p, drop out (set to zero) nodes in the input expression, and scale the remaining nodes by 1/(1-p). Note that there are two kinds of dropout:
- Regular dropout: where we perform dropout at training time and then scale the outputs by the keep probability at test time.
- Inverted dropout: where we perform dropout and scaling at training time, and do not need to do anything at test time.
DyNet implements the latter, so you only need to apply dropout at training time, and do not need to perform any scaling at test time.
Parameters: - x (dynet.Expression) – Input expression
- p (number) – The dropout probability
Returns: The dropped out expression \(y=\frac{1}{1-\texttt{p}}x\circ z, z\sim\text{Bernoulli}(1-\texttt{p})\)
Return type:
-
dynet.
dropout_dim
(x, d, p)¶ Dropout along one dimension
Identical to the dropout operation except the dropout mask is the same across one dimension. Use this if you want to drop entire columns or rows of a matrix, for example.
For now this only supports tensors of order <= 3 (with or without batch dimension)
Parameters: - x (dynet.Expression) – Input expression
- d (int) – Dimension along which to drop
- p (number) – The dropout probability
Returns: The dropped expression
Return type:
-
dynet.
dropout_batch
(x, p)¶ Dropout entire elements of a minibatch
Identical to the dropout operation except entire batch elements are dropped
Parameters: - x (dynet.Expression) – Input expression
- p (number) – The dropout probability
Returns: The dropped expression
Return type:
-
dynet.
block_dropout
(x, p)¶ Block dropout
Identical to the dropout operation, but either drops out all or no values in the expression, as opposed to making a decision about each value individually.
Parameters: - x (dynet.Expression) – Input expression
- p (number) – The dropout probability
Returns: The block dropout expression
Return type:
Linear algebra operations¶
-
dynet.
affine_transform
(exprs)¶ Affine transform
This performs an affine transform over an arbitrary (odd) number of expressions held in the input initializer list xs. The first expression is the “bias,” which is added to the expression as-is. The remaining expressions are multiplied together in pairs, then added. A very common usage case is the calculation of the score for a neural network layer (e.g. \(b + Wz\)) where b is the bias, W is the weight matrix, and z is the input. In this case
xs[0] = b
,xs[1] = W
, andxs[2] = z
.Parameters: exprs (list) – A list containing an odd number of expressions Returns: An expression equal to: xs[0] + xs[1]*xs[2] + xs[3]*xs[4] + ...
Return type: dynet.Expression
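A brief sketch of the common b + Wz usage (the dimensions are arbitrary):
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((8, 30))
b = pc.add_parameters((8,))

dy.renew_cg()
z = dy.inputVector([0.0] * 30)
# One fused node, equivalent to dy.parameter(b) + dy.parameter(W) * z
y = dy.affine_transform([dy.parameter(b), dy.parameter(W), z])
print(y.dim())                     # ((8,), 1)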
-
dynet.
dot_product
(x, y)¶ Dot Product
Calculate the dot product \(x^Ty=\sum_i x_iy_i\)
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: \(x^Ty=\sum_i x_iy_i\)
Return type:
-
dynet.
transpose
(x, dims=[1, 0])¶ Transpose a matrix
Get the transpose of the matrix, or if dims is specified shuffle the dimensions arbitrarily.
Note: This is O(1) if either the row or column dimension is 1, and O(n) otherwise.
Parameters: - x (dynet.Expression) – Input expression
- dims (list) – The dimensions to swap. The ith dimension of the output will be equal to the dims[i] dimension of the input. dims must have the same number of dimensions as x.
Returns: \(x^T\) / the shuffled expression
Return type:
-
dynet.
inverse
(x)¶ Matrix Inverse
Takes the inverse of a matrix (not implemented on GPU yet, although contributions are welcome: issue). Note that back-propagating through an inverted matrix can sometimes cause numerical stability problems.
Parameters: x (dynet.Expression) – Input expression Returns: Inverse of x Return type: dynet.Expression
-
dynet.
trace_of_product
(x, y)¶ Trace of Matrix Product
Takes the trace of the product of matrices. (not implemented on GPU yet, although contributions are welcome: issue).
Parameters: - x (dynet.Expression) – The first input expression
- y (Expression) – The second input expression
Returns: \(\text{Tr}(xy)\)
Return type:
-
dynet.
logdet
(x)¶ Log determinant
Takes the log of the determinant of a matrix. (not implemented on GPU yet, although contributions are welcome: issue).
Parameters: x (dynet.Expression) – Input expression Returns: \(\log(\vert x\vert)\) Return type: dynet.Expression
Convolution/Pooling operations¶
-
dynet.
conv2d
(x, f, stride, is_valid=True)¶ 2D convolution without bias
2D convolution operator without bias parameters.
VALID
andSAME
convolutions are supported. To understand the distinction, consider the case where the stride is 1:
SAME
: the output size is the same as the input size. To do so, one needs to pad the input so the filter can sweep outside of the input maps.VALID
: the output size shrinks by filter_size - 1
, and the filters always sweep at valid positions inside the input maps. No padding is needed.
In detail, assume
- Input feature maps:
XH x XW x XC x N
- Filters:
FH x FW x XC x FC
- Strides:
strides[0]
andstrides[1]
are row (h
) and col (w
) stride, respectively.
For the
SAME
convolution: the output height (YH
) and width (YW
) are computed as:YH = ceil(float(XH) / float(strides[0]))
YW = ceil(float(XW) / float(strides[1]))
and the paddings are computed as:
pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
pad_top = pad_along_height / 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width / 2
pad_right = pad_along_width - pad_left
For the
VALID
convolution: the output height (YH) and width (YW
) are computed as:YH = ceil(float(XH - FH + 1) / float(strides[0]))
YW = ceil(float(XW - FW + 1) / float(strides[1]))
and the paddings are always zeros.
Parameters: - x (dynet.Expression) – The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
- f (dynet.Expression) – 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
- stride (list) – the row and column strides
Keyword Arguments: is_valid (bool) – ‘VALID’ convolution or ‘SAME’ convolution, default is True (‘VALID’) (default: (True))
Returns: The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Return type:
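An illustrative sketch of the shapes involved (the input size, filter size and channel counts are made up):
import dynet as dy
import numpy as np

pc = dy.ParameterCollection()
f = pc.add_parameters((3, 3, 1, 16))               # 3x3 filters, 1 -> 16 channels

dy.renew_cg()
x = dy.inputTensor(np.random.rand(28, 28, 1))      # 28x28 input, 1 channel
y = dy.conv2d(x, dy.parameter(f), [1, 1], is_valid=True)
print(y.dim())                                     # ((26, 26, 16), 1) for 'VALID'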
-
dynet.
conv2d_bias
(x, f, b, stride, is_valid=True)¶ 2D convolution with bias
2D convolution operator with bias parameters.
VALID
andSAME
convolutions are supported. To understand the distinction, consider the case where the stride is 1:
SAME
: the output size is the same as the input size. To do so, one needs to pad the input so the filter can sweep outside of the input maps.VALID
: the output size shrinks by filter_size - 1
, and the filters always sweep at valid positions inside the input maps. No padding is needed.
In detail, assume
- Input feature maps:
XH x XW x XC x N
- Filters:
FH x FW x XC x FC
- Strides:
strides[0]
andstrides[1]
are row (h
) and col (w
) stride, respectively.
For the
SAME
convolution: the output height (YH
) and width (YW
) are computed as:YH = ceil(float(XH) / float(strides[0]))
YW = ceil(float(XW) / float(strides[1]))
and the paddings are computed as:
pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
pad_top = pad_along_height / 2
pad_bottom = pad_along_height - pad_top
pad_left = pad_along_width / 2
pad_right = pad_along_width - pad_left
For the
VALID
convolution: the output height (YH) and width (YW
) are computed as:YH = ceil(float(XH - FH + 1) / float(strides[0]))
YW = ceil(float(XW - FW + 1) / float(strides[1]))
and the paddings are always zeros.
Parameters: - x (dynet.Expression) – The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
- f (dynet.Expression) – 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
- b (dynet.Expression) – The bias (1D: Co)
- stride (list) – the row and column strides
Keyword Arguments: is_valid (bool) – ‘VALID’ convolution or ‘SAME’ convolution, default is True (‘VALID’) (default: (True))
Returns: The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Return type:
-
dynet.
maxpooling2d
(x, ksize, stride, is_valid=True)¶ 2D maxpooling
2D maxpooling operator.
VALID
andSAME
maxpooling are supported.Parameters: - x (dynet.Expression) – The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
- ksize (list) – the max pooling 2d window size
- stride (list) – the row and column strides
Keyword Arguments: is_valid (bool) – ‘VALID’ or ‘SAME’, default is True (‘VALID’) (default: (True))
Returns: The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Return type:
-
dynet.
filter1d_narrow
(x, y)¶ [summary]
[description]
Parameters: - x (dynet.Expression) – The first input expression
- y (dynet.Expression) – The second input expression
Returns: TODO
Return type:
-
dynet.
kmax_pooling
(x, k, d=1)¶ Kmax-pooling operation
Select out k maximum values along a given dimension, in the same order as they appear. This will result in the size of the given dimension being changed to k.
Parameters: - x (dynet.Expression) –
- k (unsigned) – Number of maximum values to retrieve along the given dimension
Keyword Arguments: d (unsigned) – Dimension on which to perform kmax-pooling (default: (1))
Returns: Return type:
-
dynet.
circ_conv
(u, v)¶ Circular convolution
Calculate the circular convolution \([u * v]_k=\sum_i u_iv_{(k-i) \mod d}\)
Parameters: - u (dynet.Expression) – The first input expression
- v (dynet.Expression) – The second input expression
Returns: \(u * v\)
Return type:
-
dynet.
circ_corr
(u, v)¶ Circular correlation
Calculate the circular correlation \([u \star v]_k=\sum_i u_iv_{(i + k) \mod d}\)
Parameters: - u (dynet.Expression) – The first input expression
- v (dynet.Expression) – The second input expression
Returns: \(u \star v\)
Return type:
Tensor operations¶
Remark: Compiling the contraction operations takes a lot of time with CUDA. For this reason, only the CPU implementation is compiled by default. If you need those operations, you need to un-comment this line in the source before compiling. TODO: make this simpler.
-
dynet.
contract3d_1d
(x, y)¶ Contracts a rank 3 tensor and a rank 1 tensor into a rank 2 tensor
The resulting tensor \(z\) has coordinates \(z_{ij} = \sum_k x_{ijk} y_k\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
Returns: Matrix dynet.Expression
-
dynet.
contract3d_1d_bias
(x, y, b)¶ Same as
contract3d_1d
with an additional bias parameterThe resulting tensor \(z\) has coordinates \(z_{ij} = b_{ij}+\sum_k x_{ijk} y_k\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
- b (dynet.Expression) – Bias vector
Returns: Matrix dynet.Expression
-
dynet.
contract3d_1d_1d
(x, y, z)¶ Contracts a rank 3 tensor and two rank 1 tensors into a rank 1 tensor
This is the equivalent of calling
contract3d_1d
and then performing a matrix vector multiplication.The resulting tensor \(t\) has coordinates \(t_i = \sum_{j,k} x_{ijk} y_k z_j\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
- z (dynet.Expression) – Vector
Returns: Vector dynet.Expression
-
dynet.
contract3d_1d_1d_bias
(x, y, z, b)¶ Same as
contract3d_1d_1d
with an additional bias parameterThis is the equivalent of calling
contract3d_1d
and then performing an affine transform.The resulting tensor \(t\) has coordinates \(t_i = b_i + \sum_{j,k} x_{ijk} y_k z_j\)
Parameters: - x (dynet.Expression) – Rank 3 tensor
- y (dynet.Expression) – Vector
- z (dynet.Expression) – Vector
- b (dynet.Expression) – Bias vector
Returns: Vector dynet.Expression
Normalization operations¶
-
dynet.
layer_norm
(x, g, b)¶ Layer normalization
Performs layer normalization :
\[\begin{split}\begin{split} \mu &= \frac 1 n \sum_{i=1}^n x_i\\ \sigma &= \sqrt{\frac 1 n \sum_{i=1}^n (x_i-\mu)^2}\\ y&=\frac {\boldsymbol{g}} \sigma \circ (\boldsymbol{x}-\mu) + \boldsymbol{b}\\ \end{split}\end{split}\]Reference : Ba et al., 2016
Parameters: - x (dynet.Expression) – Input expression (possibly batched)
- g (dynet.Expression) – Gain (same dimension as x, no batch dimension)
- b (dynet.Expression) – Bias (same dimension as x, no batch dimension)
Returns: An expression of the same dimension as
x
dynet.Expression
-
dynet.
weight_norm
(w, g)¶ Weight normalization
Performs weight normalization :
\[\begin{split}\begin{split} \hat{w} &= g\frac{w}{\Vert w\Vert}\\ \end{split}\end{split}\]Reference : Salimans, Kingma 2016
Parameters: - w (dynet.Expression) – Input expression (weight parameter)
- g (dynet.Expression) – Gain (scalar expression, usually also a parameter)
Returns: An expression of the same dimension as
w
dynet.Expression
Recurrent Neural Networks¶
RNN Builders¶
-
class
dynet.
_RNNBuilder
¶ -
disable_dropout
()¶ [summary]
[description]
-
initial_state
(vecs=None, update=True)¶ Get a
dynet.RNNState
This initializes a
dynet.RNNState
by loading the parameters in the computation graphParameters: - vecs (list) – Initial hidden state for each layer as a list of
dynet.Expression
s (default: {None}) - update (bool) – trainer updates internal parameters (default: {True})
Returns: dynet.RNNState
used to feed inputs, transduce sequences, etc. dynet.RNNState
-
initial_state_from_raw_vectors
(vecs=None, update=True)¶ Get a
dynet.RNNState
This initializes a
dynet.RNNState
by loading the parameters in the computation graphUse this if you want to initialize the hidden states with values directly rather than expressions.
Parameters: - vecs (list) – Initial hidden state for each layer as a list of numpy arrays (default: {None})
- update (bool) – trainer updates internal parameters (default: {True})
Returns: dynet.RNNState
used to feed inputs, transduce sequences, etc. dynet.RNNState
-
param_collection
()¶
-
set_dropout
(f)¶ [summary]
[description]
Parameters: f (float) – [description]
-
-
class
dynet.
SimpleRNNBuilder
¶ Bases:
dynet._RNNBuilder
[summary]
[description]
-
get_parameter_expressions
()¶ Retrieve the internal parameters expressions of the RNN
The output is a list with one item per layer. Each item is a list containing \(W_{hx},W_{hh},b_h\)
Returns: List of parameter expressions for each layer list Raises: ValueError
– This raises an exception if initial_state hasn't been called, because it requires the parameters to be loaded in the computation graph. However, it prevents the parameters from being loaded twice in the computation graph (compared to dynet.parameter(rnn.get_parameters()[0][0])
for example).
-
get_parameters
()¶ Retrieve the internal parameters of the RNN
The output is a list with one item per layer. Each item is a list containing \(W_{hx},W_{hh},b_h\)
Returns: List of parameters for each layer list
-
-
class
dynet.
GRUBuilder
¶ Bases:
dynet._RNNBuilder
[summary]
[description]
-
get_parameter_expressions
()¶ Retrieve the internal parameters expressions of the GRU
The output is a list with one item per layer. Each item is a list containing \(W_{zx},W_{zh},b_z,W_{rx},W_{rh},b_r,W_{hx},W_{hh},b_h\)
Returns: List of parameter expressions for each layer list Raises: ValueError
– This raises an exception if initial_state hasn't been called, because it requires the parameters to be loaded in the computation graph. However, it prevents the parameters from being loaded twice in the computation graph (compared to dynet.parameter(rnn.get_parameters()[0][0])
for example).
-
get_parameters
()¶ Retrieve the internal parameters of the GRU
The output is a list with one item per layer. Each item is a list containing \(W_{zx},W_{zh},b_z,W_{rx},W_{rh},b_r,W_{hx},W_{hh},b_h\)
Returns: List of parameters for each layer list
-
-
class
dynet.
VanillaLSTMBuilder
(layers, input_dim, hidden_dim, model, ln_lstm=False, forget_bias=1.0)¶ Bases:
dynet._RNNBuilder
VanillaLSTM allows you to create a "standard" LSTM, i.e. with decoupled input and forget gates and no peephole connections
This cell runs according to the following dynamics :
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]Parameters: - layers (int) – Number of layers
- input_dim (int) – Dimension of the input
- hidden_dim (int) – Dimension of the recurrent units
- model (dynet.ParameterCollection) – ParameterCollection to hold the parameters
- ln_lstm (bool) – Whether to use layer normalization
- forget_bias (float) – value to use as the forget gate bias (default: 1.0)
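A minimal sketch of building and running such an LSTM over a short sequence (the dimensions and the zero inputs are arbitrary):
import dynet as dy

pc = dy.ParameterCollection()
lstm = dy.VanillaLSTMBuilder(2, 50, 100, pc)   # 2 layers, input 50, hidden 100

dy.renew_cg()
s = lstm.initial_state()
xs = [dy.inputVector([0.0] * 50) for _ in range(5)]
hs = s.transduce(xs)        # list of output expressions h_1 .. h_5
s1 = s.add_input(xs[0])     # or step by step
h1 = s1.output()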
-
get_parameter_expressions
()¶ Retrieve the internal parameters expressions of the VanillaLSTM
The output is a list with one item per layer. Each item is a list containing \(W_x,W_h,b\) where \(W_x,W_h\) are stacked versions of the individual gate matrices:
h/x
+------+
|  i   |
+------+
|  f   |
+------+
|  o   |
+------+
|  c   |
+------+
Returns: List of parameter expressions for each layer list Raises: ValueError
– This raises an exception if initial_state hasn’t been called, because it requires the parameters to be loaded in the computation graph. However, it prevents the parameters from being loaded twice in the computation graph (compared to dynet.parameter(rnn.get_parameters()[0][0])
for example).
-
get_parameters
()¶ Retrieve the internal parameters of the VanillaLSTM
The output is a list with one item per layer. Each item is a list containing \(W_x,W_h,b\), where \(W_x,W_h\) are stacked versions of the individual gate matrices, with the blocks for the \(i\), \(f\), \(o\) and \(c\) gates stacked vertically (in that order).
Returns: List of parameters for each layer list
-
set_dropout_masks
(batch_size=1)¶ Set dropout masks at the beginning of a sequence for a specific batch size
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element
You need to call this __AFTER__ calling initial_state
Parameters: batch_size (int) – Batch size (default: {1})
-
set_dropouts
(d, d_r)¶ Set the dropout rates
The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016
More specifically, dropout masks \(\mathbf{z_x}\sim \text{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \text{Bernoulli}(1-d_h)\) are sampled at the start of each sequence.
The dynamics of the cell are then modified to:
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation. A short usage sketch follows the parameter list below.
Parameters: - d (number) – Dropout rate \(d_x\) for the input \(x_t\)
- d_r (number) – Dropout rate \(d_h\) for the recurrent output \(h_t\)
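A sketch of how set_dropouts and set_dropout_masks are typically combined (rates, dimensions and batch size are illustrative):
import dynet as dy

pc = dy.ParameterCollection()
lstm = dy.VanillaLSTMBuilder(1, 50, 100, pc)
lstm.set_dropouts(0.3, 0.3)             # d_x on the input, d_h on the recurrent state

dy.renew_cg()
state = lstm.initial_state()
lstm.set_dropout_masks(batch_size=32)   # sample the masks AFTER initial_state, once per (batched) sequence
# ... then feed the batched inputs through state.add_input(...) as usual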
-
class
dynet.
CompactVanillaLSTMBuilder
(layers, input_dim, hidden_dim, model)¶ Bases:
dynet._RNNBuilder
CompactVanillaLSTMBuilder creates a “standard” LSTM, i.e. with decoupled input and forget gates and no peephole connections.
This cell runs according to the following dynamics:
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]Parameters: - layers (int) – Number of layers
- input_dim (int) – Dimension of the input
- hidden_dim (int) – Dimension of the recurrent units
- model (dynet.ParameterCollection) – ParameterCollection to hold the parameters
-
get_parameter_expressions
()¶ Retrieve the internal parameters expressions of the CompactVanillaLSTM
The output is a list with one item per layer. Each item is a list containing \(W_x,W_h,b\), where \(W_x,W_h\) are stacked versions of the individual gate matrices, with the blocks for the \(i\), \(f\), \(o\) and \(c\) gates stacked vertically (in that order).
Returns: List of parameter expressions for each layer list Raises: ValueError
– This raises an exception if initial_state hasn’t been called, because it requires the parameters to be loaded in the computation graph. However, it prevents the parameters from being loaded twice in the computation graph (compared to dynet.parameter(rnn.get_parameters()[0][0])
for example).
-
get_parameters
()¶ Retrieve the internal parameters of the CompactVanillaLSTM
The output is a list with one item per layer. Each item is a list containing \(W_x,W_h,b\), where \(W_x,W_h\) are stacked versions of the individual gate matrices, with the blocks for the \(i\), \(f\), \(o\) and \(c\) gates stacked vertically (in that order).
Returns: List of parameters for each layer list
-
set_dropout_masks
(batch_size=1)¶ Set dropout masks at the beginning of a sequence for a specific batch size
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element
You need to call this __AFTER__ calling initial_state
Parameters: batch_size (int) – Batch size (default: {1})
-
set_dropouts
(d, d_r)¶ Set the dropout rates
The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016
More specifically, dropout masks \(\mathbf{z_x}\sim \text{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \text{Bernoulli}(1-d_h)\) are sampled at the start of each sequence.
The dynamics of the cell are then modified to:
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation.
Parameters: - d (number) – Dropout rate \(d_x\) for the input \(x_t\)
- d_r (number) – Dropout rate \(d_h\) for the recurrent output \(h_t\)
-
set_weightnoise
(std)¶ Set the gaussian weight noise
Parameters: std (number) – Standard deviation of weight noise
-
class
dynet.
CoupledLSTMBuilder
¶ Bases:
dynet._RNNBuilder
CoupledLSTMBuilder creates an LSTM unit with coupled input and forget gates as well as peephole connections.
More specifically, here are the equations for the dynamics of this cell:
\[\begin{split}\begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+W_{ic}c_{t-1}+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+W_{oc}c_{t}+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split}\end{split}\]-
get_parameter_expressions
()¶ Retrieve the internal parameters expressions of the LSTM
The output is a list with one item per layer. Each item is a list containing \(W_{ix},W_{ih},W_{ic},b_i,W_{ox},W_{oh},W_{oc},b_o,W_{cx},W_{ch},b_c\)
Returns: List of parameter expressions for each layer list Raises: ValueError
– This raises an exception if initial_state hasn’t been called, because it requires the parameters to be loaded in the computation graph. However, it prevents the parameters from being loaded twice in the computation graph (compared to dynet.parameter(rnn.get_parameters()[0][0])
for example).
-
get_parameters
()¶ Retrieve the internal parameters of the LSTM
The output is a list with one item per layer. Each item is a list containing \(W_{ix},W_{ih},W_{ic},b_i,W_{ox},W_{oh},W_{oc},b_o,W_{cx},W_{ch},b_c\)
Returns: List of parameters for each layer list
-
-
class
dynet.
FastLSTMBuilder
¶ Bases:
dynet._RNNBuilder
Builder for the FastLSTM variant of the LSTM unit (see the C++ documentation for the exact cell definition).
-
get_parameter_expressions
()¶ Retrieve the internal parameters expressions of the FastLSTM
The output is a list with one item per layer. Each item is a list containing \(W_{ix},W_{ih},W_{ic},b_i,W_{ox},W_{oh},W_{oc},b_o,W_{cx},W_{ch},b_c\)
Returns: List of parameter expressions for each layer list Raises: ValueError
– This raises an exception if initial_state hasn’t been called, because it requires the parameters to be loaded in the computation graph. However, it prevents the parameters from being loaded twice in the computation graph (compared to dynet.parameter(rnn.get_parameters()[0][0])
for example).
-
get_parameters
()¶ Retrieve the internal parameters of the FastLSTM
The output is a list with one item per layer. Each item is a list containing \(W_{ix},W_{ih},W_{ic},b_i,W_{ox},W_{oh},W_{oc},b_o,W_{cx},W_{ch},b_c\)
Returns: List of parameters for each layer list
-
-
class
dynet.
BiRNNBuilder
(num_layers, input_dim, hidden_dim, model, rnn_builder_factory, builder_layers=None)¶ Bases:
object
Builder for BiRNNs that delegates to regular RNNs and wires them together.
builder = BiRNNBuilder(1, 128, 100, model, LSTMBuilder)
[o1,o2,o3] = builder.transduce([i1,i2,i3])
-
add_inputs
(es)¶ Returns the list of state pairs (stateF, stateB) obtained by adding inputs to both the forward (stateF) and backward (stateB) RNNs. Does not preserve the internal state after adding the inputs.
Parameters: es (list) – a list of Expression
see also transduce(xs)
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState pairs. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
-
transduce
(es)¶ Returns the list of output Expressions obtained by adding the given inputs, one by one, to both the forward and backward RNNs, and concatenating their outputs.
Parameters: es (list) – a list of Expression
see also add_inputs(xs)
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState pairs. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
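For example, a sketch contrasting the two calls on a BiRNNBuilder (dimensions and inputs are illustrative):
import dynet as dy

pc = dy.ParameterCollection()
birnn = dy.BiRNNBuilder(1, 128, 100, pc, dy.LSTMBuilder)

dy.renew_cg()
xs = [dy.inputVector([0.0] * 128) for _ in range(3)]

outputs = birnn.transduce(xs)        # 3 output Expressions (memory efficient)
state_pairs = birnn.add_inputs(xs)   # 3 pairs (forward RNNState, backward RNNState)
fwd_state, bwd_state = state_pairs[-1]
top_fwd = fwd_state.h()[-1]          # last forward state of the top layer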
-
RNN state¶
-
class
dynet.
RNNState
¶ This is the main class for working with RNNs / LSTMs / GRUs. Request an RNNState initial_state() from a builder, and then progress from there.
-
add_input
(x)¶ This computes \(h_t = \text{RNN}(x_t)\)
Parameters: x (dynet.Expression) – Input expression Returns: New RNNState dynet.RNNState
-
add_inputs
(xs)¶ Returns the list of states obtained by adding the given inputs to the current state, one by one.
see also
transduce(xs)
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
Parameters: xs (list) – list of input expressions Returns: List of RNNState list
-
b
()¶ Get the underlying RNNBuilder
In case you need to set dropout or other stuff.
Returns: Underlying RNNBuilder dynet.RNNBuilder
-
h
()¶ tuple of expressions representing the output of each hidden layer of the current step. the actual output of the network is at h()[-1].
-
prev
()¶ Gets previous RNNState
In case you need to rewind
-
s
()¶ Tuple of expressions representing the hidden state of the current step.
For SimpleRNN, s() is the same as h(). For LSTM, s() is the series of memory (cell) vectors, followed by the series returned by h():
(c[1],...,c[num_layers], h[1],...,h[num_layers])
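For instance, on a single-layer LSTM state (a sketch; dimensions are illustrative):
import dynet as dy

pc = dy.ParameterCollection()
lstm = dy.VanillaLSTMBuilder(1, 10, 20, pc)

dy.renew_cg()
state = lstm.initial_state().add_input(dy.inputVector([0.0] * 10))
print(len(state.h()))   # 1: one output vector h per layer
print(len(state.s()))   # 2: (c[1], h[1]), the memory cell followed by the output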
-
set_h
(es=None)¶ Manually set the output \(h_t\)
Parameters: es (list) – List of expressions, one for each layer (default: {None}) Returns: New RNNState dynet.RNNState
-
set_s
(es=None)¶ Manually set the hidden states
This is different from
set_h
because, for LSTMs for instance, this also sets the cell state. The format is [new_c[0],...,new_c[n],new_h[0],...,new_h[n]]
Parameters: es (list) – List of expressions, in this format : [new_c[0],...,new_c[n],new_h[0],...,new_h[n]]
(default: {None}) Returns: New RNNState dynet.RNNState
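A sketch of manually setting the state of a one-layer LSTM (the zero vectors are illustrative):
import dynet as dy

pc = dy.ParameterCollection()
lstm = dy.VanillaLSTMBuilder(1, 10, 20, pc)

dy.renew_cg()
c0 = dy.inputVector([0.0] * 20)                # initial memory cell
h0 = dy.inputVector([0.0] * 20)                # initial output vector
state = lstm.initial_state().set_s([c0, h0])   # format: [new_c[0], ..., new_h[0], ...]
state = state.add_input(dy.inputVector([0.0] * 10))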
-
transduce
(xs)¶ returns the list of output Expressions obtained by adding the given inputs to the current state, one by one.
see also
add_inputs(xs)
.transduce(xs) is different from .add_inputs(xs) in the following way:
- .add_inputs(xs) returns a list of RNNState. RNNState objects can be queried in various ways. In particular, they allow access to the previous state, as well as to the state-vectors (h() and s()).
- .transduce(xs) returns a list of Expression. These are just the output expressions. For many cases, this suffices. transduce is much more memory efficient than add_inputs.
Parameters: xs (list) – list of input expressions Returns: List of output expressions list
-
Softmax Builders¶
-
class
dynet.
SoftmaxBuilder
¶ Interface for building softmax layers
A softmax layer returns a probability distribution over \(C\) classes given a vector \(h\in\mathbb R^d\), with
\[p(c_i)\propto \exp(W_i^Th + b_i)\ \forall i\in\{1\ldots C\}\]where \(W\in \mathbb R^{C\times d}, b \in \mathbb R^C\)
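As an illustration, a loss-computation sketch using neg_log_softmax; the StandardSoftmaxBuilder constructor arguments shown (input dimension, number of classes, parameter collection) are an assumption, so check the signature of the concrete builder you use:
import dynet as dy

pc = dy.ParameterCollection()
softmax = dy.StandardSoftmaxBuilder(100, 1000, pc)   # assumed: (input_dim, num_classes, pc)

dy.renew_cg()
h = dy.inputVector([0.0] * 100)          # some hidden representation
loss = softmax.neg_log_softmax(h, 42)    # -log p(class 42 | h)
logp = softmax.full_log_distribution(h)  # full vector of log-probabilities (may be slow)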
-
full_log_distribution
(x, update=True)¶ Returns an Expression representing a vector the size of the number of classes.
The ith dimension gives \(\log p(c_i | x)\). This function may be SLOW. Avoid if possible.
Parameters: - x (dynet.Expression) – Input vector
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Vector of \(\log p(c\mid x)\) dynet.Expression
-
full_logits
(x, update=True)¶ Returns the logits (before application of the softmax)
The ith dimension gives \(W_i^Tx + b_i\)
Parameters: - x (dynet.Expression) – Input vector
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Expression for the logits dynet.Expression
-
neg_log_softmax
(x, c, update=True)¶ Negative log probability of a class
Given class \(c\) and vector \(x\), this returns \(-\log(p(c \mid x))\)
Parameters: - x (dynet.Expression) – Input vector
- c (unsigned) – Class id
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Negative log probability of the given class dynet.Expression
-
neg_log_softmax_batch
(x, c, update=True)¶ Batched version of
neg_log_softmax
Parameters: - x (dynet.Expression) – Input vector (batched)
- c (list) – list of class ids (one per batch element)
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Negative log probability of the given classes dynet.Expression
-
param_collection
()¶ Returns the ParameterCollection containing the softmax parameters
The first parameter in the ParameterCollection is the weight matrix, the second the bias vector (if any)
Returns: Subcollection holding the parameters ParameterCollection
-
sample
(x)¶ Sample from the softmax distribution
Parameters: x (dynet.Expression) – Input vector Returns: Sampled class int
-
-
class
dynet.
StandardSoftmaxBuilder
¶ Bases:
dynet.SoftmaxBuilder
This class implements the standard Softmax
-
class
dynet.
ClassFactoredSoftmaxBuilder
¶ Bases:
dynet.SoftmaxBuilder
Class factored softmax
Each output label \(i\) is assigned to a class \(c\), and the probability factors as \(p(i\mid h)=p(i\mid h, c)\,p(c\mid h)\), where \(c\) is the class and \(i\) the label within that class (the subclass)
-
class_log_distribution
(x, update=True)¶ Get log distribution over classes
Parameters: - x (dynet.Expression) – Input vector
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Vector of \(\log p(c\mid x)\) dynet.Expression
-
class_logits
(x, update=True)¶ Returns the logits over classes
Parameters: - x (dynet.Expression) – Input vector
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Expression for the logits dynet.Expression
-
subclass_log_distribution
(x, classid, update=True)¶ Get log distribution over subclasses of class
Parameters: - x (dynet.Expression) – Input vector
- classid (int) – class index
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Vector of \(\log p(i\mid x, \texttt{classid})\) dynet.Expression
-
subclass_logits
(x, classid, update=True)¶ Logits over subclasses of class
Parameters: - x (dynet.Expression) – Input vector
- classid (int) – class index
- update (bool) – Whether to update the parameters or not (default: {True})
Returns: Expression for the logits dynet.Expression
-
Optimizers¶
-
class
dynet.
Trainer
¶ Generic trainer
-
learning_rate
¶ number – Global learning rate for all parameters
-
get_clip_threshold
()¶ Get clipping threshold
Returns: Gradient clipping threshold Return type: number
-
restart
(learning_rate=None)¶ Restarts the optimizer
Clears all momentum values and similar accumulated statistics (if applicable)
Parameters: learning_rate (number) – (Optional) resets the learning rate
-
set_clip_threshold
(thr)¶ Set clipping threshold
To deactivate clipping, set the threshold to be <=0
Parameters: thr (number) – Clipping threshold
-
set_sparse_updates
(su)¶ Sets updates to sparse updates
DyNet trainers support two types of updates for lookup parameters, sparse and dense. Sparse updates are the default. They have the potential to be faster, as they only touch the parameters that have non-zero gradients. However, they may not always be faster (particulary on GPU with mini-batch training), and are not precisely numerically correct for some update rules such as MomentumTrainer and AdamTrainer. Thus, if you set this variable to false, the trainer will perform dense updates and be precisely correct, and maybe faster sometimes. :param su: flag to activate/deactivate sparse updates :type su: bool
-
status
()¶ Outputs information about the trainer to stderr
(number of updates since last call, number of clipped gradients, learning rate, etc…)
-
update
()¶ Update the parameters
The update equation is different for each trainer; check the online C++ documentation for more details on what each trainer does
-
update_epoch
(r)¶ DEPRECATED: do not use.
-
update_subset
(updated_params, updated_lookups)¶ Update a subset of parameters
Only use this as a last resort; a more elegant way to update only a subset of parameters is to use the “update” keyword in dy.parameter or Parameter.expr() to specify which parameters need to be updated during the creation of the computation graph.
Parameters: - updated_params (list) – Indices of parameters to update
- updated_lookups (list) – Indices of lookup parameters to update
-
-
class
dynet.
SimpleSGDTrainer
¶ Bases:
dynet.Trainer
Stochastic gradient descent trainer
This trainer performs stochastic gradient descent, the go-to optimization procedure for neural networks.
Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained Keyword Arguments: learning_rate (number) – Initial learning rate (default: 0.1)
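A minimal training-loop sketch with SimpleSGDTrainer (the model and data are placeholders):
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((1, 10))
trainer = dy.SimpleSGDTrainer(pc, learning_rate=0.1)

for x_vec, y_val in [([0.5] * 10, 1.0)] * 20:      # dummy data
    dy.renew_cg()
    y_pred = dy.parameter(W) * dy.inputVector(x_vec)
    loss = dy.squared_distance(y_pred, dy.inputVector([y_val]))
    loss.forward()
    loss.backward()
    trainer.update()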
-
class
dynet.
CyclicalSGDTrainer
¶ Bases:
dynet.Trainer
This trainer performs stochastic gradient descent with a cyclical learning rate as proposed in Smith, 2015.
This uses a triangular function with optional exponential decay.
More specifically, at each update, the learning rate \(\eta\) is updated according to :
\[\begin{split} \begin{split} \text{cycle} &= \left\lfloor 1 + \frac{\texttt{it}}{2 \times\texttt{step_size}} \right\rfloor\\ x &= \left\vert \frac{\texttt{it}}{\texttt{step_size}} - 2 \times \text{cycle} + 1\right\vert\\ \eta &= \eta_{\text{min}} + (\eta_{\text{max}} - \eta_{\text{min}}) \times \max(0, 1 - x) \times \gamma^{\texttt{it}}\\ \end{split}\end{split}\]Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained
Keyword Arguments: - learning_rate_min (number) – Lower learning rate (default: {0.01})
- learning_rate_max (number) – Upper learning rate (default: {0.1})
- step_size (number) – Period of the triangular function in number of iterations (__not__ epochs). According to the original paper, this should be set around (2-8) x (training iterations in epoch) (default: {2000})
- gamma (number) – Learning rate upper bound decay parameter (default: {0.0})
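For example, a sketch setting step_size to roughly 4 epochs’ worth of updates, following the guideline above (all numbers are illustrative):
import dynet as dy

pc = dy.ParameterCollection()
updates_per_epoch = 500                      # e.g. number of minibatches per epoch
trainer = dy.CyclicalSGDTrainer(
    pc,
    learning_rate_min=0.01,
    learning_rate_max=0.1,
    step_size=4 * updates_per_epoch,         # 2-8 x iterations per epoch, per the original paper
    gamma=1.0,                               # 1.0 disables the exponential decay of the upper bound
)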
-
class
dynet.
MomentumSGDTrainer
¶ Bases:
dynet.Trainer
Stochastic gradient descent with momentum
This is a modified version of the SGD algorithm with momentum to stabilize the gradient trajectory.
Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained
Keyword Arguments: - learning_rate (number) – Initial learning rate (default: 0.1)
- mom (number) – Momentum (default: 0.9)
-
class
dynet.
AdagradTrainer
¶ Bases:
dynet.Trainer
Adagrad optimizer
The adagrad algorithm assigns a different learning rate to each parameter.
Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained
Keyword Arguments: - learning_rate (number) – Initial learning rate (default: 0.1)
- eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-20)
-
class
dynet.
AdadeltaTrainer
¶ Bases:
dynet.Trainer
AdaDelta optimizer
The AdaDelta optimizer is a variant of Adagrad aiming to prevent vanishing learning rates.
Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained
Keyword Arguments: - eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-6)
- rho (number) – Update parameter for the moving average of updates in the numerator (default: 0.95)
-
class
dynet.
RMSPropTrainer
¶ Bases:
dynet.Trainer
RMSProp optimizer
The RMSProp optimizer is a variant of Adagrad where the squared sum of previous gradients is replaced with a moving average with parameter rho.
Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained
Keyword Arguments: - learning_rate (number) – Initial learning rate (default: 0.001)
- eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-8)
- rho (number) – Update parameter for the moving average (rho = 0 is equivalent to using Adagrad) (default: 0.9)
-
class
dynet.
AdamTrainer
¶ Bases:
dynet.Trainer
Adam optimizer
The Adam optimizer is similar to RMSProp but uses unbiased estimates of the first and second moments of the gradient
Parameters: m (dynet.ParameterCollection) – ParameterCollection to be trained
Keyword Arguments: - alpha (number) – Initial learning rate (default: 0.001)
- beta_1 (number) – Moving average parameter for the mean (default: 0.9)
- beta_2 (number) – Moving average parameter for the variance (default: 0.999)
- eps (number) – Epsilon parameter to prevent numerical instability (default: 1e-8)
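For instance (the values shown are the documented defaults; clipping is optional):
import dynet as dy

pc = dy.ParameterCollection()
trainer = dy.AdamTrainer(pc, alpha=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8)
trainer.set_clip_threshold(5.0)   # gradient clipping; a threshold <= 0 disables it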
MultiDevice¶
-
dynet.
to_device
(e, device_str)¶ Copy Expression’s values between devices. Creates a new expression with e’s values on device device_str.
Parameters: - e (dynet.Expression) – Expression
- device_str (string) – a device name
Returns: dynet.Expression
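A sketch of copying an expression between devices; the device names “CPU” and “GPU:0” are assumptions and must match the devices DyNet was started with (e.g. via --dynet-devices):
import dynet as dy

dy.renew_cg()
e_cpu = dy.inputVector([1.0, 2.0, 3.0])   # lives on the default device
e_gpu = dy.to_device(e_cpu, "GPU:0")      # copy of the values on the GPU
e_back = dy.to_device(e_gpu, "CPU")       # and back again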
C++ Reference manual¶
Core functionalities¶
Computation Graph¶
The ComputationGraph is the workhorse of DyNet. From the DyNet technical report:
[The] computation graph represents symbolic computation, and the results of the computation are evaluated lazily: the computation is only performed once the user explicitly asks for it (at which point a “forward” computation is triggered). Expressions that evaluate to scalars (i.e. loss values) can also be used to trigger a “backward” computation, computing the gradients of the computation with respect to the parameters.
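The same lazy evaluation can be observed from the Python bindings; a minimal sketch (in C++, ComputationGraph::forward and ComputationGraph::backward play the analogous roles):
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((1, 3))

dy.renew_cg()                                             # start a fresh computation graph
y = dy.parameter(W) * dy.inputVector([1.0, 2.0, 3.0])     # only builds the symbolic graph
loss = dy.squared_distance(y, dy.inputVector([0.0]))
print(loss.value())   # triggers the forward computation
loss.backward()       # triggers the backward computation (gradients w.r.t. the parameters)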
-
int dynet
::
get_number_of_active_graphs
()¶ Gets the number of active graphs.
This is 0 or 1, you can’t create more than one graph at once
- Return
- Number of active graphs
-
unsigned dynet
::
get_current_graph_id
()¶ Get id of the current active graph.
This can help check whether a graph is stale
- Return
- Id of the current graph
-
struct dynet
::
ComputationGraph
¶ - #include <dynet.h>
Computation graph where nodes represent forward and backward intermediate values, and edges represent functions of multiple values.
To represent the fact that a function may have multiple arguments, edges have a single head and 0, 1, 2, or more tails. (Constants, inputs, and parameters are represented as functions of 0 parameters.) Example: given the function z = f(x, y), z, x, and y are nodes, and there is an edge representing f which points to the z node (i.e., its head), and x and y are the tails of the edge. You shouldn’t need to use most methods from the ComputationGraph except for
backward
since most of them are available directly from the Expression class.Public Functions
-
dynet::ComputationGraph
ComputationGraph
()¶ Default constructor.
-
VariableIndex dynet::ComputationGraph
add_input
(real s, Device *device)¶ Add scalar input.
The computational network will pull inputs in from the user’s data structures and make them available to the computation
- Return
- The index of the created variable
- Parameters
s
: Real number
device
: The device to place input value
-
VariableIndex dynet::ComputationGraph
add_input
(const real *ps, Device *device)¶ Add scalar input by pointer.
The computational network will pull inputs in from the user’s data structures and make them available to the computation
- Return
- The index of the created variable
- Parameters
ps
: Pointer to a real number
device
: The device to place input value
-
VariableIndex dynet::ComputationGraph
add_input
(const Dim &d, const std::vector<float> &data, Device *device)¶ Add multidimensional input.
The computational network will pull inputs in from the user’s data structures and make them available to the computation
- Return
- The index of the created variable
- Parameters
d
: Desired shape of the input
data
: Input data (as a 1 dimensional array)
device
: The device to place input value
-
VariableIndex dynet::ComputationGraph
add_input
(const Dim &d, const std::vector<float> *pdata, Device *device)¶ Add multidimensional input by pointer.
The computational network will pull inputs in from the user’s data structures and make them available to the computation
- Return
- The index of the created variable
- Parameters
d
: Desired shape of the input
pdata
: Pointer to the input data (as a 1 dimensional array)
device
: The device to place input value
-
VariableIndex dynet::ComputationGraph
add_input
(const Dim &d, const std::vector<unsigned int> &ids, const std::vector<float> &data, Device *device, float defdata = 0.f)¶ Add sparse input.
The computational network will pull inputs in from the user’s data structures and make them available to the computation. Represents specified (not learned) inputs to the network in sparse array format, with an optional default value.
- Return
- The index of the created variable
- Parameters
d
: Desired shape of the input
ids
: The indexes of the data points to update
data
: The data points corresponding to each index
device
: The device to place input value
defdata
: The default data with which to set the unspecified data points
-
VariableIndex dynet::ComputationGraph
add_parameters
(Parameter p)¶ Add a parameter to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: Parameter to be added
-
VariableIndex dynet::ComputationGraph
add_parameters
(LookupParameter p)¶ Add a full matrix of lookup parameters to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: LookupParameter to be added
-
VariableIndex dynet::ComputationGraph
add_const_parameters
(Parameter p)¶ Add a parameter to the computation graph (but don’t update)
- Return
- The index of the created variable
- Parameters
p
: Parameter to be added
-
VariableIndex dynet::ComputationGraph
add_const_parameters
(LookupParameter p)¶ Add a full matrix of lookup parameter to the computation graph (but don’t update)
- Return
- The index of the created variable
- Parameters
p
: LookupParameter to be added
-
VariableIndex dynet::ComputationGraph
add_lookup
(LookupParameter p, const unsigned *pindex)¶ Add a lookup parameter to the computation graph.
Use pindex to point to a memory location where the index will live that the caller owns
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindex
: Pointer to the index to lookup
-
VariableIndex dynet::ComputationGraph
add_lookup
(LookupParameter p, unsigned index)¶ Add a lookup parameter to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
index
: Index to lookup
-
VariableIndex dynet::ComputationGraph
add_lookup
(LookupParameter p, const std::vector<unsigned> *pindices)¶ Add lookup parameters to the computation graph.
Use pindices to point to a memory location where the indices will live that the caller owns
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindices
: Pointer to the indices to lookup
-
VariableIndex dynet::ComputationGraph
add_lookup
(LookupParameter p, const std::vector<unsigned> &indices)¶ Add lookup parameters to the computation graph.
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
indices
: Indices to lookup
-
VariableIndex dynet::ComputationGraph
add_const_lookup
(LookupParameter p, const unsigned *pindex)¶ Add a lookup parameter to the computation graph.
Just like add_lookup, but don’t optimize the lookup parameters
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindex
: Pointer to the indices to lookup
-
VariableIndex dynet::ComputationGraph
add_const_lookup
(LookupParameter p, unsigned index)¶ Add a lookup parameter to the computation graph.
Just like add_lookup, but don’t optimize the lookup parameters
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
index
: Index to lookup
-
VariableIndex dynet::ComputationGraph
add_const_lookup
(LookupParameter p, const std::vector<unsigned> *pindices)¶ Add lookup parameters to the computation graph.
Just like add_lookup, but don’t optimize the lookup parameters
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
pindices
: Pointer to the indices to lookup
-
VariableIndex dynet::ComputationGraph
add_const_lookup
(LookupParameter p, const std::vector<unsigned> &indices)¶ Add lookup parameters to the computation graph.
Just like add_lookup, but don’t optimize the lookup parameters
- Return
- The index of the created variable
- Parameters
p
: Lookup parameter from which to pick
indices
: Indices to lookup
- template <class Function>
-
VariableIndex dynet::ComputationGraph
add_function
(const std::initializer_list<VariableIndex> &arguments)¶ Add a function to the computation graph.
This is what is called when creating an expression
- Return
- The index of the output variable
- Parameters
arguments
: List of the arguments indices
- Template Parameters
Function
: Function to be applied
- template <class Function, typename… Args>
-
VariableIndex dynet::ComputationGraph
add_function
(const std::initializer_list<VariableIndex> &arguments, Args&&... side_information)¶ Add a function to the computation graph (with side information)
This is what is called when creating an expression
- Return
- The index of the output variable
- Parameters
arguments
: List of the arguments indices
side_information
: Side information that is needed to compute the function
- Template Parameters
Function
: Function to be applied
-
void dynet::ComputationGraph
clear
()¶ Reset ComputationGraph to a newly created state.
[long description]
-
void dynet::ComputationGraph
checkpoint
()¶ Set a checkpoint.
-
void dynet::ComputationGraph
revert
()¶ Revert to last checkpoint.
-
Dim &dynet::ComputationGraph
get_dimension
(VariableIndex index) const¶ Get dimension of a node.
- Return
- Dimension
- Parameters
index
: Variable index of the node
-
const Tensor &dynet::ComputationGraph
forward
(const Expression &last)¶ Run complete forward pass from first node to given one, ignoring all precomputed values.
- Return
- Value of the
last
Expression after execution - Parameters
last
: Expression up to which the forward pass must be computed
-
const Tensor &dynet::ComputationGraph
forward
(VariableIndex i)¶ Run complete forward pass from first node to given one, ignoring all precomputed values.
- Return
- Value of the end Node after execution
- Parameters
i
: Variable index of the node up to which the forward pass must be computed
-
const Tensor &dynet::ComputationGraph
incremental_forward
(const Expression &last)¶ Run forward pass from the last computed node to given one.
Useful if you want to add nodes and evaluate just the new parts.
- Return
- Value of the
last
Expression after execution - Parameters
last
: Expression up to which the forward pass must be computed
-
const Tensor &dynet::ComputationGraph
incremental_forward
(VariableIndex i)¶ Run forward pass from the last computed node to given one.
Useful if you want to add nodes and evaluate just the new parts.
- Return
- Value of the end Node after execution
- Parameters
last
: Variable index of the node up to which the forward pass must be computed
-
const Tensor &dynet::ComputationGraph
get_value
(VariableIndex i)¶ Get forward value for node at index i.
Performs forward evaluation if not available (may compute more than strictly what is needed).
- Return
- Requested value
- Parameters
i
: Index of the variable from which you want the value
-
const Tensor &dynet::ComputationGraph
get_value
(const Expression &e)¶ Get forward value for the given expression.
Performs forward evaluation if not available (may compute more than strictly what is needed).
- Return
- Requested value
- Parameters
e
: Expression from which you want the value
-
const Tensor &dynet::ComputationGraph
get_gradient
(VariableIndex i)¶ Get gradient for node at index i.
Performs backward pass if not available (may compute more than strictly what is needed).
- Return
- Requested gradient
- Parameters
i
: Index of the variable from which you want the gradient
-
const Tensor &dynet::ComputationGraph
get_gradient
(const Expression &e)¶ Get forward gradient for the given expression.
Performs backward pass if not available (may compute more than strictly what is needed).
- Return
- Requested gradient
- Parameters
e
: Expression from which you want the gradient
-
void dynet::ComputationGraph
invalidate
()¶ Clears forward caches (for get_value etc).
-
void dynet::ComputationGraph
backward
(const Expression &last, bool full = false)¶ Computes backward gradients from the front-most evaluated node.
The parameter
full
specifies whether the gradients should be computed for all nodes (true
) or only non-constant nodes.By default, a node is constant unless
- it is a parameter node
- it depends on a non-constant node
Thus, functions of constants and inputs are considered as constants.
Turn
full
on if you want to retrieve gradients w.r.t. inputs for instance. By default this is turned off, so that the backward pass ignores nodes which have no influence on gradients w.r.t. parameters for efficiency.- Parameters
last
: Expression from which to compute the gradient
full
: Whether to compute all gradients (including with respect to constant nodes).
-
void dynet::ComputationGraph
backward
(VariableIndex i, bool full = false)¶ Computes backward gradients from node i (assuming it already been evaluated).
The parameter
full
specifies whether the gradients should be computed for all nodes (true
) or only non-constant nodes.By default, a node is constant unless
- it is a parameter node
- it depends on a non-constant node
Thus, functions of constants and inputs are considered as constants.
Turn
full
on if you want to retrieve gradients w.r.t. inputs for instance. By default this is turned off, so that the backward pass ignores nodes which have no influence on gradients w.r.t. parameters for efficiency.- Parameters
i
: Index of the node from which to compute the gradient
full
: Whether to compute all gradients (including with respect to constant nodes). Turn this on if you want to retrieve gradients w.r.t. inputs for instance. By default this is turned off, so that the backward pass ignores nodes which have no influence on gradients w.r.t. parameters for efficiency.
-
void dynet::ComputationGraph
print_graphviz
() const¶ Used for debugging.
-
unsigned dynet::ComputationGraph
get_id
() const¶ Get the unique graph ID.
This ID is incremented by 1 each time a computation graph is created
- Return
- Graph id
-
dynet::ComputationGraph
Nodes¶
Nodes are constituents of the computation graph. The end user doesn’t interact with Nodes but with Expressions.
However, implementing new operations requires creating a new subclass of the Node class described below.
-
struct dynet
::
Node
¶ - #include <dynet.h>
Represents an SSA variable.
Contains information on the computation node: arguments, output value and gradient of the output with respect to the function. This class must be inherited to implement any new operation. See nodes.cc for examples. An operation on expressions can then be created from the new Node; see expr.h/expr.cc for examples
Subclassed by dynet::Abs, dynet::Acos, dynet::Acosh, dynet::AddVectorToAllColumns, dynet::AffineTransform, dynet::Asin, dynet::Asinh, dynet::Atan, dynet::Atanh, dynet::Average, dynet::AverageColumns, dynet::BinaryLogLoss, dynet::BlockDropout, dynet::CircularConvolution, dynet::CircularCorrelation, dynet::Concatenate, dynet::ConcatenateToBatch, dynet::Constant, dynet::ConstantMinusX, dynet::ConstantPlusX, dynet::ConstParameterNode, dynet::ConstrainedSoftmax, dynet::ConstScalarMultiply, dynet::Conv2D, dynet::Cos, dynet::Cosh, dynet::Cube, dynet::CwiseMultiply, dynet::CwiseQuotient, dynet::CwiseSum, dynet::DotProduct, dynet::Dropout, dynet::DropoutBatch, dynet::DropoutDim, dynet::Erf, dynet::Exp, dynet::ExponentialLinearUnit, dynet::Filter1DNarrow, dynet::FlipGradient, dynet::FoldRows, dynet::GaussianNoise, dynet::Hinge, dynet::HingeDim, dynet::HuberDistance, dynet::Identity, dynet::InnerProduct3D_1D, dynet::InnerProduct3D_1D_1D, dynet::InputNode, dynet::KMaxPooling, dynet::KMHNGram, dynet::L1Distance, dynet::L2Norm, dynet::Log, dynet::LogDet, dynet::LogGamma, dynet::LogisticSigmoid, dynet::LogSoftmax, dynet::LogSumExp, dynet::LogSumExpDimension, dynet::MatrixInverse, dynet::MatrixMultiply, dynet::Max, dynet::MaxDimension, dynet::MaxPooling1D, dynet::MaxPooling2D, dynet::Min, dynet::MinDimension, dynet::MomentBatches, dynet::MomentDimension, dynet::MomentElements, dynet::Negate, dynet::NoBackprop, dynet::PairwiseRankLoss, dynet::ParameterNodeBase, dynet::PickBatchElements, dynet::PickElement, dynet::PickNegLogSoftmax, dynet::PickRange, dynet::PoissonRegressionLoss, dynet::Pow, dynet::RandomBernoulli, dynet::RandomGumbel, dynet::RandomNormal, dynet::RandomUniform, dynet::Rectify, dynet::Reshape, dynet::RestrictedLogSoftmax, dynet::ScalarInputNode, dynet::SelectCols, dynet::SelectRows, dynet::SigmoidLinearUnit, dynet::Sin, dynet::Sinh, dynet::Softmax, dynet::SoftSign, dynet::SparseInputNode, dynet::Sparsemax, dynet::SparsemaxLoss, dynet::Sqrt, dynet::Square, dynet::SquaredEuclideanDistance, dynet::SquaredNorm, dynet::StdBatches, dynet::StdDimension, dynet::StdElements, dynet::StridedSelect, dynet::Sum, dynet::SumDimension, dynet::SumElements, dynet::Tan, dynet::Tanh, dynet::ToDevice, dynet::TraceOfProduct, dynet::Transpose, dynet::VanillaLSTMC, dynet::VanillaLSTMGates, dynet::VanillaLSTMH, dynet::WeightNormalization
Public Functions
-
virtual Dim dynet::Node
dim_forward
(const std::vector<Dim> &xs) const = 0¶ Compute dimensions of result for given dimensions of inputs.
Also checks to make sure inputs are compatible with each other
- Return
- Dimension of the output
- Parameters
xs
: Vector containing the dimensions of the inputs
-
virtual std::string dynet::Node
as_string
(const std::vector<std::string> &args) const = 0¶ Returns important information for debugging.
See nodes-conv.cc for examples
- Return
- String description of the node
- Parameters
args
: String descriptions of the arguments
-
size_t dynet::Node
aux_storage_size
() const¶ Size of the auxiliary storage.
In general, this will return an empty size, but if a component needs to store extra information in the forward pass for use in the backward pass, it can request the memory here (nb. you could put it on the Node object, but in general, edges should not allocate tensor memory since memory is managed centrally for the entire computation graph).
- Return
- Size
-
virtual void dynet::Node
forward_impl
(const std::vector<const Tensor *> &xs, Tensor &fx) const = 0¶ Forward computation.
This function contains the logic for the forward pass. Some implementation remarks from nodes.cc:
- fx can be understood as a pointer to the (preallocated) location for the result of forward to be stored
- fx is not initialized, so after calling forward fx must point to the correct answer
- fx can be repointed to an input, if forward(x) evaluates to x (e.g., in reshaping)
- scalars results of forward are placed in fx.v[0]
- DYNET manages its own memory, not Eigen, and it is configured with the EIGEN_NO_MALLOC option. If you get an error about Eigen attempting to allocate memory, it is (probably) because of an implicit creation of a temporary variable. To tell Eigen this is not necessary, the noalias() method is available. If you really do need a temporary variable, its capacity must be requested by Node::aux_storage_size
Note on debugging problems with differentiable components
- fx is uninitialized when forward is called- are you relying on it being 0?
- Parameters
xs
: Pointers to the inputs
fx
: pointer to the (preallocated) location for the result of forward to be stored
-
virtual void dynet::Node
backward_impl
(const std::vector<const Tensor *> &xs, const Tensor &fx, const Tensor &dEdf, unsigned i, Tensor &dEdxi) const = 0¶ Accumulates the derivative of E with respect to the ith argument to f, that is, xs[i].
This function contains the logic for the backward pass. Some implementation remarks from nodes.cc:
- dEdxi MUST ACCUMULATE a result since multiple calls to forward may depend on the same x_i. Even, e.g., Identity must be implemented as dEdx1 += dEdf. THIS IS EXTREMELY IMPORTANT
- scalars results of forward are placed in fx.v[0]
- DYNET manages its own memory, not Eigen, and it is configured with the EIGEN_NO_MALLOC option. If you get an error about Eigen attempting to allocate memory, it is (probably) because of an implicit creation of a temporary variable. To tell Eigen this is not necessary, the noalias() method is available. If you really do need a temporary variable, its capacity must be requested by Node::aux_storage_size
Note on debugging problems with differentiable components
- dEdxi must accumulate (see the first point above!)
- Parameters
xs
: Pointers to inputs
fx
: Output
dEdf
: Gradient of the objective w.r.t the output of the node
i
: Index of the input w.r.t which we take the derivative
dEdxi
: Gradient of the objective w.r.t the input of the node
-
virtual bool dynet::Node
supports_multibatch
() const¶ Whether this node supports computing multiple batches in one call.
If true, forward and backward will be called once with a multi-batch tensor. If false, forward and backward will be called multiple times for each item.
- Return
- Support for multibatch
-
virtual bool dynet::Node
supports_multidevice
() const¶ Whether this node supports processing inputs/outputs on multiple devices.
DyNet will throw an error if you try to process inputs and outputs on different devices unless this is activated.
- Return
- Support for multi-device
-
void dynet::Node
forward
(const std::vector<const Tensor *> &xs, Tensor &fx) const¶ perform the forward/backward passes in one or multiple calls
- Parameters
xs
: Pointers to the inputs
fx
: pointer to the (preallocated) location for the result of forward to be stored
-
void dynet::Node
backward
(const std::vector<const Tensor *> &xs, const Tensor &fx, const Tensor &dEdf, unsigned i, Tensor &dEdxi) const¶ perform the backward passes in one or multiple calls
- Parameters
xs
: Pointers to inputs
fx
: Output
dEdf
: Gradient of the objective w.r.t the output of the node
i
: Index of the input w.r.t which we take the derivative
dEdxi
: Gradient of the objective w.r.t the input of the node
-
virtual int dynet::Node
autobatch_sig
(const ComputationGraph &cg, SigMap &sm) const¶ Signature for automatic batching. This will be equal only for nodes that can be combined. Returns 0 for unbatchable functions.
-
virtual std::vector<int> dynet::Node
autobatch_concat
(const ComputationGraph &cg) const¶ Which inputs can be batched. This will be true for inputs that should be concatenated when autobatching, and false for inputs that should be shared among all batches.
-
virtual Node *dynet::Node
autobatch_pseudo_node
(const ComputationGraph &cg, const std::vector<VariableIndex> &batch_ids) const¶ Create a pseudo-node for autobatching. This will combine multiple nodes into one big node for the automatic batching functionality. When a node representing one component of the mini-batch can be used as-is, it is OK to just return the null pointer; otherwise we should make the appropriate changes and return a new node.
-
virtual void dynet::Node
autobatch_reshape
(const ComputationGraph &cg, const std::vector<VariableIndex> &batch_ids, const std::vector<int> &concat, std::vector<const Tensor *> &xs, Tensor &fx) const¶ Reshape the tensors for autobatching. Takes in info and reshapes the dimensions of xs (for which “concat” is true) and fx. By default does no reshaping, which is OK for componentwise operations.
-
void dynet::Node
autobatch_reshape_concatonly
(const ComputationGraph &cg, const std::vector<VariableIndex> &batch_ids, const std::vector<int> &concat, std::vector<const Tensor *> &xs, Tensor &fx) const¶ Reshape the tensors for autobatching. Takes in info and reshapes the dimensions of xs (for which “concat” is true) and fx by concatenating their batches.
Public Members
-
void *dynet::Node
aux_mem
¶ This will usually be null, but if your node needs to store intermediate values between forward and backward, you can store them here. Request the number of bytes you need from aux_storage_size(). Note: this memory will be on the CPU or GPU, depending on your computation backend
-
virtual Dim dynet::Node
Parameters and Model¶
Parameters are things that are optimized. in contrast to a system like Torch where computational modules may have their own parameters, in DyNet parameters are just parameters.
To deal with sparse updates, there are two parameter classes:
- Parameters represents a vector, matrix, (eventually higher order tensors) of parameters. These are densely updated.
- LookupParameters represents a table of vectors that are used to embed a set of discrete objects. These are sparsely updated.
-
struct dynet
::
ParameterStorageBase
¶ - #include <model.h>
This is the base class for ParameterStorage and LookupParameterStorage, the objects handling the actual parameters.
You can access the storage from any Parameter (resp. LookupParameter) class, use it only to do low level manipulations.
Subclassed by dynet::LookupParameterStorage, dynet::ParameterStorage
Public Functions
-
virtual void dynet::ParameterStorageBase
scale_parameters
(float a) = 0¶ Scale the parameters.
- Parameters
a
: scale factor
-
virtual void dynet::ParameterStorageBase
scale_gradient
(float a) = 0¶ Scale the gradient.
- Parameters
a
: scale factor
-
virtual void dynet::ParameterStorageBase
zero
() = 0¶ Set the parameters to 0.
-
virtual void dynet::ParameterStorageBase
squared_l2norm
(float *sqnorm) const = 0¶ Get the parameter squared l2 norm.
- Parameters
sqnorm
: Pointer to the float holding the result
-
virtual void dynet::ParameterStorageBase
g_squared_l2norm
(float *sqnorm) const = 0¶ Get the squared l2 norm of the gradient w.r.t. these parameters.
- Parameters
sqnorm
: Pointer to the float holding the result
-
virtual bool dynet::ParameterStorageBase
is_updated
() const = 0¶ Check whether the parameter is updated.
-
virtual bool dynet::ParameterStorageBase
has_grad
() const = 0¶ Check whether the gradient is zero or not (true if gradient is non-zero)
-
virtual size_t dynet::ParameterStorageBase
size
() const = 0¶ Get the size (number of scalar parameters)
- Return
- Number of scalar parameters
-
virtual void dynet::ParameterStorageBase
-
struct dynet
::
ParameterStorage
¶ - #include <model.h>
Storage class for Parameters.
Inherits from dynet::ParameterStorageBase
Subclassed by dynet::ParameterStorageCreator
Public Functions
-
void dynet::ParameterStorage
copy
(const ParameterStorage &val)¶ Copy from another ParameterStorage.
- Parameters
val
: ParameterStorage to copy from
-
void dynet::ParameterStorage
accumulate_grad
(const Tensor &g)¶ Add a tensor to the gradient.
After this method gets called, g <- g + d
- Parameters
g
: Tensor to add
-
void dynet::ParameterStorage
clear
()¶ Clear the gradient (set it to 0)
-
void dynet::ParameterStorage
clip
(float left, float right)¶ Clip the values to the range [left, right].
Public Members
-
std::string dynet::ParameterStorage
name
¶ Name of this parameter
-
Dim dynet::ParameterStorage
dim
¶ Dimensions of the parameter tensor
-
Tensor dynet::ParameterStorage
values
¶ Values of the parameter
-
Tensor dynet::ParameterStorage
g
¶ Values of the gradient w.r.t. this parameter
-
bool dynet::ParameterStorage
updated
¶ Whether this is updated
-
bool dynet::ParameterStorage
nonzero_grad
¶ Whether the gradient is non-zero
-
ParameterCollection *dynet::ParameterStorage
owner
¶ Pointer to the collection that “owns” this parameter
-
void dynet::ParameterStorage
-
struct dynet
::
LookupParameterStorage
¶ - #include <model.h>
Storage class for LookupParameters.
Inherits from dynet::ParameterStorageBase
Subclassed by dynet::LookupParameterStorageCreator
Public Functions
-
void dynet::LookupParameterStorage
initialize
(unsigned index, const std::vector<float> &val)¶ Initialize one particular lookup.
- Parameters
index
: Index of the lookup to initialize
val
: Values
-
void dynet::LookupParameterStorage
copy
(const LookupParameterStorage &val)¶ Copy from another LookupParameterStorage.
- Parameters
val
: Other LookupParameterStorage to copy from
-
void dynet::LookupParameterStorage
accumulate_grad
(const Tensor &g)¶ Add a Tensor to the gradient of the whole lookup matrix.
after this
grads<-grads + g
- Parameters
g
: Tensor to add to the gradient
-
void dynet::LookupParameterStorage
accumulate_grad
(unsigned index, const Tensor &g)¶ Add a Tensor to the gradient of one of the lookups.
after this
grads[index]<-grads[index] + g
- Parameters
index
: Index of the lookup
g
: Tensor to add to the gradient of that lookup
-
void dynet::LookupParameterStorage
accumulate_grads
(unsigned n, const unsigned *ids_host, const unsigned *ids_dev, float *g)¶ Add tensors to multiple lookups.
After this method gets called,
grads[ids_host[i]] <- grads[ids_host[i]] + g[i*dim.size():(i+1)*dim.size()]
- Parameters
n
: size of ids_host
ids_host
: Indices of the gradients to update
ids_dev
: [To be documented] (only for GPU)
g
: Values
Public Members
-
std::string dynet::LookupParameterStorage
name
¶ Name of this parameter
-
Dim dynet::LookupParameterStorage
all_dim
¶ Total dimension
-
Tensor dynet::LookupParameterStorage
all_values
¶ Values for all dimensions at once
-
Tensor dynet::LookupParameterStorage
all_grads
¶ Gradient values for all dimensions at once
-
Dim dynet::LookupParameterStorage
dim
¶ Dimension for one lookup
-
std::vector<Tensor> dynet::LookupParameterStorage
values
¶ List of values for each lookup
-
std::vector<Tensor> dynet::LookupParameterStorage
grads
¶ List of gradient values for each lookup
-
std::unordered_set<unsigned> dynet::LookupParameterStorage
non_zero_grads
¶ Gradients are sparse, so track which components are nonzero
-
bool dynet::LookupParameterStorage
updated
¶ Whether this lookup parameter should be updated
-
bool dynet::LookupParameterStorage
nonzero_grad
¶ Whether the gradient is non-zero
-
ParameterCollection *dynet::LookupParameterStorage
owner
¶ Pointer to the collection that “owns” this parameter
-
void dynet::LookupParameterStorage
-
struct dynet
::
Parameter
¶ - #include <model.h>
Object representing a trainable parameter.
This objects acts as a high level component linking the actual parameter values (ParameterStorage) and the ParameterCollection. As long as you don’t want to do low level hacks at the ParameterStorage level, this is what you will use.
Public Functions
Constructor.
This is called by the model, you shouldn’t need to use it
- Parameters
p
: Shared pointer to the parameter storage
-
ParameterStorage &dynet::Parameter
get_storage
() const¶ Get underlying ParameterStorage object.
- Return
- ParameterStorage holding the parameter values
-
string dynet::Parameter
get_fullname
() const¶ Get the full name of the ParameterStorage object.
-
float dynet::Parameter
current_weight_decay
() const¶ Get the current weight decay for the parameters.
-
void dynet::Parameter
scale_gradient
(float s)¶ Scales the gradient (multiplies by
s
)- Parameters
s
: scale
Public Members
-
std::shared_ptr<ParameterStorage> dynet::Parameter
p
¶ Pointer to the storage for this Parameter
-
struct dynet
::
LookupParameter
¶ - #include <model.h>
Object representing a trainable lookup parameter.
Public Functions
-
LookupParameterStorage &dynet::LookupParameter
get_storage
() const¶ Get underlying LookupParameterStorage object.
- Return
- LookupParameterStorage holding the parameter values
-
void dynet::LookupParameter
initialize
(unsigned index, const std::vector<float> &val) const¶ Initialize one particular column.
- Parameters
index
: Index of the column to be initialized
val
: Values to initialize the column with
-
void dynet::LookupParameter
zero
()¶ Zero the parameters.
-
string dynet::LookupParameter
get_fullname
() const¶ Get the full name of the ParameterStorage object.
-
Dim dynet::LookupParameter
dim
() const¶ Shape of the lookup parameter.
- Return
- Shape as a
Dim
object
-
std::vector<Tensor> *dynet::LookupParameter
values
()¶ Values of the lookup parameter.
- Return
- Values as a
Tensor
object
-
float dynet::LookupParameter
current_weight_decay
() const¶ Get the current weight decay for the parameters.
-
void dynet::LookupParameter
scale
(float s)¶ Scales the parameter (multiplies by
s
)- Parameters
s
: scale
-
void dynet::LookupParameter
scale_gradient
(float s)¶ Scales the gradient (multiplies by
s
)- Parameters
s
: scale
-
void dynet::LookupParameter
set_updated
(bool b)¶ Set the parameter as updated.
- Parameters
b
: Update status
-
bool dynet::LookupParameter
is_updated
()¶ Check the update status.
- Return
- Update status
Public Members
-
std::shared_ptr<LookupParameterStorage> dynet::LookupParameter
p
¶ Pointer to the storage for this Parameter
-
LookupParameterStorage &dynet::LookupParameter
-
class dynet
::
ParameterCollection
¶ - #include <model.h>
This is a collection of parameters.
If you need a matrix of parameters or a lookup table, ask an instance of this class. It knows how to serialize itself. Parameters know how to track their gradients, but any extra information (like velocity) will live here
Subclassed by dynet::Model
Public Functions
-
dynet::ParameterCollection
ParameterCollection
()¶ Constructor.
-
float dynet::ParameterCollection
gradient_l2_norm
() const¶ Returns the L2 norm of the gradient.
Use this to look for gradient vanishing/exploding
- Return
- L2 norm of the gradient
-
void dynet::ParameterCollection
reset_gradient
()¶ Sets all gradients to zero.
-
Parameter dynet::ParameterCollection
add_parameters
(const Dim &d, float scale = 0.0f, const std::string &name = "", Device *device = dynet::default_device)¶ Add parameters to model and returns Parameter object.
creates a ParameterStorage object holding a tensor of dimension
d
and returns a Parameter object (to be used as input in the computation graph). The coefficients are sampled according to the scale parameter.
- Return
- Parameter object to be used in the computation graph
- Parameters
d
: Shape of the parameter
scale
: If scale is non-zero, initializes according to \(\mathcal U([-\mathrm{scale},+\mathrm{scale}])\), otherwise uses Glorot initialization
name
: Name of the parameter
device
: Device placement for the parameter
-
Parameter dynet::ParameterCollection
add_parameters
(const Dim &d, Device *device)¶ Add parameters to model and returns Parameter object.
creates a ParameterStorage object holding a tensor of dimension
d
and returns a Parameter object (to be used as input in the computation graph).
- Return
- Parameter object to be used in the computation graph
- Parameters
d
: Shape of the parameter
device
: Device placement for the parameter
-
Parameter dynet::ParameterCollection
add_parameters
(const Dim &d, const std::string &name, Device *device = dynet::default_device)¶ Add parameters to model and returns Parameter object.
creates a ParameterStorage object holding a tensor of dimension
d
and returns a Parameter object (to be used as input in the computation graph).- Return
- Parameter object to be used in the computation graph
- Parameters
d
: Shape of the parametername
: Name of the parameterdevice
: Device placement for the parameter
-
Parameter dynet::ParameterCollection
add_parameters
(const Dim &d, const ParameterInit &init, const std::string &name = "", Device *device = dynet::default_device)¶ Add parameters with custom initializer.
- Return
- Parameter object to be used in the computation graph
- Parameters
d
: Shape of the parameterinit
: Custom initializername
: Name of the parameterdevice
: Device placement for the parameter
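For illustration, here is a minimal sketch (not taken from the DyNet sources; the shapes, names, and values are arbitrary assumptions) of how these add_parameters overloads might be used after initializing DyNet:
#include <dynet/init.h>
#include <dynet/model.h>
#include <dynet/param-init.h>

int main(int argc, char **argv) {
  dynet::initialize(argc, argv);
  dynet::ParameterCollection model;
  // 8x4 weight matrix; scale defaults to 0, so Glorot initialization is used
  dynet::Parameter W = model.add_parameters({8, 4});
  // Length-8 bias, initialized uniformly in [-0.1, 0.1] via the scale argument
  dynet::Parameter b = model.add_parameters({8}, 0.1f, "bias");
  // Custom-initializer overload: every coefficient set to 0.5
  dynet::Parameter c = model.add_parameters({8}, dynet::ParameterInitConst(0.5f), "c");
  return 0;
}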
-
std::vector<std::shared_ptr<ParameterStorageBase>> dynet::ParameterCollection
get_parameter_storages_base
() const¶ Get the base parameter storage objects in the current model.
- Return
- List of pointers to ParameterStorageBase objects
-
std::shared_ptr<ParameterStorage> dynet::ParameterCollection
get_parameter_storage
(const std::string &pname)¶ Get parameter in current model.
It is not recommended to use this
- Return
- Pointer to the ParameterStorage object
-
std::vector<std::shared_ptr<ParameterStorage>> dynet::ParameterCollection
get_parameter_storages
() const¶ Get parameters in current model.
- Return
- List of pointers to ParameterStorage objects
-
LookupParameter dynet::ParameterCollection
add_lookup_parameters
(unsigned n, const Dim &d, const std::string &name = "", Device *device = dynet::default_device)¶ Add lookup parameter to model.
Same as add_parameters. Initializes with Glorot
- Return
- LookupParameter object to be used in the computation graph
- Parameters
n
: Number of lookup indicesd
: Dimension of each embeddingname
: Name of the parameterdevice
: Device placement for the parameter
-
LookupParameter dynet::ParameterCollection
add_lookup_parameters
(unsigned n, const Dim &d, const ParameterInit &init, const std::string &name = "", Device *device = dynet::default_device)¶ Add lookup parameter with custom initializer.
- Return
- LookupParameter object to be used in the computation graph
- Parameters
n
: Number of lookup indicesd
: Dimension of each embeddinginit
: Custom initializername
: Name of the parameterdevice
: Device placement for the parameter
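A short sketch of adding lookup parameters (the table size, embedding dimension, and names below are arbitrary assumptions, not DyNet defaults):
#include <dynet/model.h>
#include <dynet/param-init.h>

void add_embeddings(dynet::ParameterCollection &model) {
  using namespace dynet;
  // 1000-entry table of 64-dimensional embeddings, Glorot-initialized
  LookupParameter embed = model.add_lookup_parameters(1000, {64});
  // The same kind of table with a custom initializer (normal, variance 0.01)
  LookupParameter embed2 =
      model.add_lookup_parameters(1000, {64}, ParameterInitNormal(0.f, 0.01f), "embed2");
}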
-
std::shared_ptr<LookupParameterStorage> dynet::ParameterCollection
get_lookup_parameter_storage
(const std::string &lookup_pname)¶ Get lookup parameter in current model.
It is not recommended to use this
- Return
- Pointer to the LookupParameterStorage object
-
std::vector<std::shared_ptr<LookupParameterStorage>> dynet::ParameterCollection
get_lookup_parameter_storages
() const¶ Get lookup parameters in current model.
- Return
- List of pointers to LookupParameterStorage objects
-
void dynet::ParameterCollection
project_weights
(float radius = 1.0f)¶ Project weights so that their L2 norm equals radius.
NOTE (Paul) : I am not sure this is doing anything currently. The argument doesn’t seem to be used anywhere… If you need this raise an issue on github
- Parameters
radius
: Target norm
-
void dynet::ParameterCollection
set_weight_decay_lambda
(float lambda)¶ Set the weight decay coefficient.
- Parameters
lambda
: Weight decay coefficient
-
const std::vector<std::shared_ptr<ParameterStorage>> &dynet::ParameterCollection
parameters_list
() const¶ Returns list of shared pointers to ParameterStorages.
You shouldn’t need to use this
- Return
- List of shared pointers to ParameterStorages
-
const std::vector<std::shared_ptr<LookupParameterStorage>> &dynet::ParameterCollection
lookup_parameters_list
() const¶ Returns list of pointers to LookupParameterStorages.
You shouldn’t need to use this
- Return
- List of pointers to LookupParameterStorages
-
size_t dynet::ParameterCollection
parameter_count
() const¶ Returns the total number of tunable parameters (i.e., scalars) contained within this model.
That is to say, a 2x2 matrix counts as four parameters.
- Return
- Number of parameters
-
size_t dynet::ParameterCollection
updated_parameter_count
() const¶ Returns total number of (scalar) parameters updated.
- Return
- number of updated parameters
-
void dynet::ParameterCollection
set_updated_param
(const Parameter *p, bool status)¶ Set the update status of a Parameter.
If status is false, the parameter will not be modified during parameter updates.
- Parameters
p
: Parameter objectstatus
: Update status
-
void dynet::ParameterCollection
set_updated_lookup_param
(const LookupParameter *p, bool status)¶ Set the update status of a LookupParameter.
If status is false, the lookup parameter will not be modified during parameter updates.
- Parameters
p
: LookupParameter objectstatus
: Update status
-
bool dynet::ParameterCollection
is_updated_param
(const Parameter *p)¶ Check the update status of a Parameter.
- Return
- Update status
- Parameters
p
: Parameter object
-
bool dynet::ParameterCollection
is_updated_lookup_param
(const LookupParameter *p)¶ Check the update status of a LookupParameter.
- Return
- Update status
- Parameters
p
: LookupParameter object
-
ParameterCollection dynet::ParameterCollection
add_subcollection
(const std::string &name = "")¶ Add a sub-collection.
This will allow you to add a ParameterCollection that is a (possibly named) subset of the original collection. This is useful if you want to save/load/update only part of the parameters in the model.
- Return
- The subcollection
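A sketch of how sub-collections might be organized (the "encoder"/"decoder" names and shapes are illustrative assumptions only):
#include <dynet/model.h>
#include <string>

void build_subcollections(dynet::ParameterCollection &model) {
  using namespace dynet;
  // Named sub-collections; their parameters can be saved/loaded/updated separately
  ParameterCollection encoder = model.add_subcollection("encoder");
  ParameterCollection decoder = model.add_subcollection("decoder");
  Parameter W_enc = encoder.add_parameters({16, 8});
  Parameter W_dec = decoder.add_parameters({8, 16});
  // get_fullname() returns the sub-collection's namespace, ending with a slash
  std::string ns = encoder.get_fullname();
}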
-
size_t dynet::ParameterCollection
size
()¶ Get size.
Get the number of parameters in the ParameterCollection
-
std::string dynet::ParameterCollection
get_fullname
() const¶ Get the namespace of the current ParameterCollection object (ends with a slash)
-
L2WeightDecay &dynet::ParameterCollection
get_weight_decay
()¶ Get the weight decay object.
-
dynet::ParameterCollection
-
struct dynet
::
ParameterInit
¶ - #include <param-init.h>
Initializers for parameters.
Allows for custom parameter initialization
Subclassed by dynet::ParameterInitConst, dynet::ParameterInitFromFile, dynet::ParameterInitFromVector, dynet::ParameterInitGlorot, dynet::ParameterInitIdentity, dynet::ParameterInitNormal, dynet::ParameterInitSaxe, dynet::ParameterInitUniform
Public Functions
-
dynet::ParameterInit
ParameterInit
()¶ Default constructor.
-
virtual void dynet::ParameterInit
initialize_params
(Tensor &values) const = 0¶ Function called upon initialization.
Whenever you inherit this struct to implement your own custom initializer, this is the function you want to overload to implement your logic.
- Parameters
values
: The tensor to be initialized. You should modify it in-place. See dynet/model.cc for some examples
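As an illustration, here is a hypothetical custom initializer (not part of DyNet) that overrides initialize_params and fills the tensor in-place with small uniform noise via TensorTools; the class name and range parameter are assumptions for the sketch:
#include <dynet/param-init.h>
#include <dynet/tensor.h>

// Hypothetical initializer: uniform noise in [-range, range]
struct ParameterInitSmallNoise : public dynet::ParameterInit {
  explicit ParameterInitSmallNoise(float range) : range(range) {}
  void initialize_params(dynet::Tensor &values) const override {
    // The tensor must be modified in-place, as the interface requires
    dynet::TensorTools::randomize_uniform(values, -range, range);
  }
  float range;
};

// Usage sketch: model.add_parameters({8, 4}, ParameterInitSmallNoise(0.05f));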
-
dynet::ParameterInit
-
struct dynet
::
ParameterInitNormal
¶ - #include <param-init.h>
Initialize parameters with samples from a normal distribution.
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitNormal
ParameterInitNormal
(float m = 0.0f, float v = 1.0f)¶ Constructor.
- Parameters
m
: Mean of the gaussian distributionv
: Variance of the gaussian distribution (reminder : the variance is the square of the standard deviation)
-
dynet::ParameterInitNormal
-
struct dynet
::
ParameterInitUniform
¶ - #include <param-init.h>
Initialize parameters with samples from a uniform distribution.
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitUniform
ParameterInitUniform
(float scale)¶ Constructor for uniform distribution centered on 0.
Samples parameters from \(\mathcal U([-\mathrm{scale},+\mathrm{scale}])\)
- Parameters
scale
: Scale of the distribution
-
dynet::ParameterInitUniform
ParameterInitUniform
(float l, float r)¶ Constructor for uniform distribution in a specific interval.
- Parameters
l
: Lower bound of the intervalr
: Upper bound of the interval
-
dynet::ParameterInitUniform
-
struct dynet
::
ParameterInitConst
¶ - #include <param-init.h>
Initialize parameters with a constant value.
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitConst
ParameterInitConst
(float c)¶ Constructor.
- Parameters
c
: Constant value
-
dynet::ParameterInitConst
-
struct dynet
::
ParameterInitIdentity
¶ - #include <param-init.h>
Initialize as the identity.
This will raise an exception if used on non square matrices
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitIdentity
ParameterInitIdentity
()¶ Constructor.
-
dynet::ParameterInitIdentity
-
struct dynet
::
ParameterInitGlorot
¶ - #include <param-init.h>
Initialize with the methods described in Glorot, 2010
In order to preserve the variance of the forward and backward flow across layers, the parameters \(\theta\) are initialized such that \(\mathrm{Var}(\theta)=\frac 2 {n_1+n_2}\) where \(n_1,n_2\) are the input and output dimensions. Important note: the underlying distribution is uniform (not gaussian)
Note: This is also known as Xavier initialization
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitGlorot
ParameterInitGlorot
(bool is_lookup = false, float gain = 1.f)¶ Constructor.
- Parameters
is_lookup
: Boolean value identifying the parameter as a LookupParametergain
: Scaling parameter. In order for the Glorot initialization to be correct, you should set this equal to \(\frac 1 {f'(0)}\) where \(f\) is your activation function
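For example, for a logistic-sigmoid activation \(f'(0)=1/4\), so the formula above suggests a gain of 4. A minimal sketch (the gain value follows from that formula and the shapes are arbitrary; neither is a DyNet default):
#include <dynet/model.h>
#include <dynet/param-init.h>

void add_sigmoid_layer(dynet::ParameterCollection &model) {
  using namespace dynet;
  // gain = 1 / f'(0) = 4 for the logistic sigmoid
  Parameter W = model.add_parameters({8, 4}, ParameterInitGlorot(false, 4.f));
}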
-
dynet::ParameterInitGlorot
-
struct dynet
::
ParameterInitSaxe
¶ - #include <param-init.h>
Initializes according to Saxe et al., 2014
Initializes as a random orthogonal matrix (unimplemented for GPU)
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitSaxe
ParameterInitSaxe
(float gain = 1.0)¶ Constructor.
-
dynet::ParameterInitSaxe
-
struct dynet
::
ParameterInitFromFile
¶ - #include <param-init.h>
Initializes from a file.
Useful for reusing weights, etc…
Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitFromFile
ParameterInitFromFile
(std::string f)¶ Constructor.
- Parameters
f
: File name (format should just be a list of values)
-
dynet::ParameterInitFromFile
-
struct dynet
::
ParameterInitFromVector
¶ - #include <param-init.h>
Initializes from a
std::vector
of floats.Inherits from dynet::ParameterInit
Public Functions
-
dynet::ParameterInitFromVector
ParameterInitFromVector
(std::vector<float> v)¶ Constructor.
- Parameters
v
: Vector of values to be used
-
dynet::ParameterInitFromVector
Tensor¶
Tensor objects provide a bridge between C++ data structures and Eigen Tensors for multidimensional data.
Concretely, as an end user you will obtain a tensor object after calling .value()
on an expression. You can then use functions described below to convert these tensors to float
s, arrays of float
s, to save and load the values, etc…
Conversely, when implementing low level nodes (e.g. for new operations), you will need to retrieve Eigen tensors from dynet tensors in order to perform efficient computation.
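A minimal sketch of this round trip, assuming dynet::initialize has already been called (the input values and shapes are placeholders):
#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/tensor.h>
#include <iostream>
#include <vector>

void inspect_values(dynet::ComputationGraph &cg) {
  using namespace dynet;
  std::vector<float> vals = {1.f, 2.f, 3.f};
  Expression x = input(cg, {3}, vals);
  Expression s = sum_elems(x);
  // .value() runs the (incremental) forward pass and returns a Tensor
  float total = as_scalar(s.value());              // 6.0
  std::vector<float> copy = as_vector(x.value());  // {1, 2, 3}
  std::cout << total << " " << copy.size() << std::endl;
}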
-
vector<Eigen::DenseIndex> dynet
::
as_vector
(const IndexTensor &v)¶ Get the array of indices in an index tensor.
For higher order tensors this returns the flattened value
- Return
- Index values
- Parameters
v
: Input index tensor
-
std::ostream &dynet
::
operator<<
(std::ostream &os, const Tensor &t)¶ You can use
cout<<tensor;
for debugging or saving.- Parameters
os
: output streamt
: Tensor
-
real dynet
::
as_scalar
(const Tensor &t)¶ Get a scalar value from an order 0 tensor.
Throws a
runtime_error
exception if the tensor has more than one element.TODO : Change for custom invalid dimension exception maybe?
- Return
- Scalar value
- Parameters
t
: Input tensor
-
std::vector<real> dynet
::
as_vector
(const Tensor &v)¶ Get the array of values in the tensor.
For higher order tensors this returns the flattened value
- Return
- Values
- Parameters
v
: Input tensor
-
real dynet
::
rand01
()¶ This is a helper function to sample uniformly in \([0,1]\).
- Return
- \(x\sim\mathcal U([0,1])\)
-
int dynet
::
rand0n
(int n)¶ This is a helper function to sample uniformly in \(\{0,\dots,n-1\}\).
- Return
- \(x\sim\mathcal U(\{0,\dots,n-1\})\)
- Parameters
n
: Upper bound (excluded)
-
real dynet
::
rand_normal
()¶ This is a helper function to sample from a normalized gaussian distribution.
- Return
- \(x\sim\mathcal N(0,1)\)
-
struct dynet
::
IndexTensor
¶ - #include <index-tensor.h>
Represents a tensor of indices.
This holds indices to locations within a dimension or tensor.
Public Functions
-
dynet::IndexTensor
IndexTensor
()¶ Create an empty tensor.
-
dynet::IndexTensor
IndexTensor
(const Dim &d, Eigen::DenseIndex *v, Device *dev, DeviceMempool mem)¶ Creates a tensor.
- Parameters
d
: Shape of the tensorv
: Pointer to the valuesdev
: Devicemem
: Memory pool
Public Members
-
Dim dynet::IndexTensor
d
¶ Shape of tensor
-
Eigen::DenseIndex *dynet::IndexTensor
v
¶ Pointer to memory
-
dynet::IndexTensor
-
struct dynet
::
Tensor
¶ - #include <tensor.h>
Represents a tensor of any order.
This provides a bridge between classic C++ types and Eigen tensors.
Public Functions
-
dynet::Tensor
Tensor
(const Dim &d, float *v, Device *dev, DeviceMempool mem)¶ Creates a tensor.
- Parameters
d
: Shape of the tensorv
: Pointer to the valuesdev
: Devicemem
: Memory pool
-
float *dynet::Tensor
batch_ptr
(unsigned bid)¶ Get the pointer for a particular batch.
Automatically broadcasting if the size is zero
- Return
- Pointer to the memory where the batch values are located
- Parameters
bid
: Batch id requested
-
bool dynet::Tensor
is_valid
() const¶ Check for NaNs and infinite values.
This is very slow: use sparingly (it’s linear in the number of elements). This raises a
std::runtime_error
exception if the Tensor is on GPU because it’s not implemented yet- Return
- Whether the tensor contains any invalid value
-
Tensor dynet::Tensor
batch_elem
(unsigned b) const¶ Get a Tensor object representing a single batch.
If this tensor only has a single batch, then broadcast. Otherwise, check to make sure that the requested batch is smaller than the number of batches.
TODO: This is a bit wasteful, as it re-calculates
bs.batch_size()
every time.- Return
- Sub tensor at batch
b
- Parameters
b
: Batch id
-
dynet::Tensor
-
struct dynet
::
TensorTools
¶ - #include <tensor.h>
Provides tools for creating, accessing, copying and modifying tensors (in-place)
Public Static Functions
-
void dynet::TensorTools
clip
(Tensor &d, float left, float right)¶ Clip the values in the tensor to a fixed range.
- Parameters
d
: Tensor to modifyleft
: Target minimum valueright
: Target maximum value
-
void dynet::TensorTools
constant
(Tensor &d, float c)¶ Fills the tensor with a constant value.
- Parameters
d
: Tensor to modifyc
: Target value
-
void dynet::TensorTools
zero
(Tensor &d)¶ Fills a tensor with zeros.
- Parameters
d
: Input tensor
-
void dynet::TensorTools
identity
(Tensor &val)¶ Set the (order 2) tensor as the identity matrix.
this throws a runtime_error exception if the tensor isn’t a square matrix
- Parameters
val
: Input tensor
-
void dynet::TensorTools
randomize_bernoulli
(Tensor &val, real p, real scale = 1.0f)¶ Fill the tensor with bernoulli random variables and scale them by scale.
- Parameters
val
: Input tensorp
: Parameter of the bernoulli distributionscale
: Scale of the random variables
-
void dynet::TensorTools
randomize_normal
(Tensor &val, real mean = 0.0f, real stddev = 1.0f)¶ Fill the tensor with gaussian random variables.
- Parameters
val
: Input tensormean
: Meanstddev
: Standard deviation
-
void dynet::TensorTools
randomize_uniform
(Tensor &val, real left = 0.0f, real right = 1.0f)¶ Fill the tensor with uniform random variables.
- Parameters
val
: Input tensorleft
: Left bound of the intervalright
: Right bound of the interval
-
void dynet::TensorTools
randomize_orthonormal
(Tensor &val, real scale = 1.0f)¶ Takes a square matrix tensor and sets it as a random orthonormal matrix.
More specifically this samples a random matrix with RandomizeUniform and then performs SVD and returns the left orthonormal matrix in the decomposition, scaled by
scale
- Parameters
val
: Input tensorscale
: Value to which the resulting orthonormal matrix will be scaled
-
float dynet::TensorTools
access_element
(const Tensor &v, int index)¶ Access element of the tensor by index in the values array.
AccessElement and SetElement are very, very slow (potentially) - use appropriately
- Return
v.v[index]
- Parameters
v
: Tensorindex
: Index in the memory
-
float dynet::TensorTools
access_element
(const Tensor &v, const Dim &index)¶ Access element of the tensor by indices in the various dimension.
This only works for matrix shaped tensors (+ batch dimension). AccessElement and SetElement are very, very slow (potentially) - use appropriately
- Return
(*v)(index[0], index[1])
- Parameters
v
: Tensorindex
: Indices in the tensor
-
void dynet::TensorTools
set_element
(const Tensor &v, int index, float value)¶ Set element of the tensor by index in the values array.
AccessElement and SetElement are very, very slow (potentially) - use appropriately
- Parameters
v
: Tensorindex
: Index in the memoryvalue
: Desired value
-
void dynet::TensorTools
copy_element
(const Tensor &l, int lindex, Tensor &r, int rindex)¶ Copy element from one tensor to another (by index in the values array)
- Parameters
l
: Source tensorlindex
: Source indexr
: Target tensorrindex
: Target index
-
void dynet::TensorTools
set_elements
(const Tensor &v, const std::vector<float> &vec)¶ Set the elements of a tensor with an array of values.
(This uses memcpy so be careful)
- Parameters
v
: Input Tensorvec
: Values
-
void dynet::TensorTools
copy_elements
(Tensor &v, const Tensor &v_src)¶ Copy one tensor into another.
- Parameters
v
: Target tensorv_src
: Source tensor
-
void dynet::TensorTools
accumulate
(Tensor &v, const Tensor &v_src)¶ Accumulate the values of one tensor into another.
- Parameters
v
: Target tensorv_src
: Source tensor
-
void dynet::TensorTools
logsumexp
(const Tensor &x, Tensor &m, Tensor &z, unsigned d = 0)¶ Calculate the logsumexp function over all columns of the tensor.
- Parameters
x
: The input tensorm
: A tensor of scratch memory to hold the maximum values of each columnz
: The output tensor
-
IndexTensor dynet::TensorTools
argmax
(const Tensor &v, unsigned dim = 0, unsigned num = 1)¶ Calculate the index of the maximum value.
- Return
- A newly allocated LongTensor consisting of argmax IDs. The length of the dimension “dim” will be “num”, consisting of the appropriate IDs.
- Parameters
v
: A tensor where each row represents a probability distributiondim
: Which dimension to take the argmax overnum
: The number of kmax values
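A sketch of recovering predicted class IDs from an expression's value tensor using argmax; the score expression is assumed to exist and to hold one score vector per batch element:
#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/tensor.h>
#include <dynet/index-tensor.h>
#include <vector>

std::vector<Eigen::DenseIndex> predicted_ids(dynet::Expression scores) {
  using namespace dynet;
  const Tensor &t = scores.value();           // runs forward as needed
  IndexTensor ids = TensorTools::argmax(t);   // argmax along dimension 0
  return as_vector(ids);                      // flattened list of argmax IDs
}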
-
IndexTensor dynet::TensorTools
categorical_sample_log_prob
(const Tensor &v, unsigned dim = 0, unsigned num = 1)¶ Calculate samples from a log probability.
- Return
- A newly allocated LongTensor consisting of argmax IDs. The length of the dimension “dim” will be “num”, consisting of the appropriate IDs.
- Parameters
v
: A tensor where each row represents a log probability distributiondim
: Which dimension to take the sample overnum
: The number of samples for each row
Dimensions¶
The Dim class holds information on the shape of a tensor. As explained in Unorthodox Design, in DyNet the dimensions are represented as the standard dimension + the batch dimension, which makes batched computation transparent.
-
DYNET_MAX_TENSOR_DIM
¶ Maximum number of dimensions supported by dynet : 7
-
struct dynet
::
Dim
¶ - #include <dim.h>
The Dim struct stores information about the dimensionality of expressions.
Batch dimension is treated separately from standard dimension.
Public Functions
-
dynet::Dim
Dim
(std::initializer_list<unsigned int> x)¶ Initialize from a list of dimensions.
The batch dimension is 1 in this case (non-batched expression)
- Parameters
x
: List of dimensions
-
dynet::Dim
Dim
(std::initializer_list<unsigned int> x, unsigned int b)¶ Initialize from a list of dimensions and a batch size.
- Parameters
x
: List of dimensionsb
: Batch size
-
dynet::Dim
Dim
(const std::vector<long> &x)¶ Initialize from a vector of dimensions.
The batch dimension is 1 in this case (non-batched expression)
- Parameters
x
: Array of dimensions
-
dynet::Dim
Dim
(const std::vector<long> &x, unsigned int b)¶ Initialize from a vector of dimensions and a batch size.
- Parameters
x
: Vector of dimensionsb
: Batch size
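A short sketch of constructing Dim objects with and without a batch dimension (the sizes are arbitrary):
#include <dynet/dim.h>

void dim_examples() {
  using namespace dynet;
  Dim d({50, 30});                     // 50x30, batch size 1
  Dim db({50, 30}, 10);                // same shape, mini-batch of 10
  unsigned r = db.rows();              // 50
  unsigned per = db.batch_size();      // 50 * 30 = 1500 (size of one batch element)
  Dim t = Dim({50, 1, 1}).truncate();  // trailing dimensions of 1 removed -> {50}
}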
-
unsigned int dynet::Dim
batch_size
() const¶ Size of a batch (product of all dimensions)
- Return
- Size of a batch
-
unsigned int dynet::Dim
sum_dims
() const¶ Sum of all dimensions within a batch.
- Return
- Sum of the dimensions within a batch
-
Dim dynet::Dim
truncate
() const¶ Remove trailing dimensions of size 1.
Iterates over the dimensions of the Dim and drops trailing dimensions equal to 1
- Return
- truncated dimension
-
Dim dynet::Dim
single_batch
() const¶ Set the batch dimension to 1.
- Return
- 1-batch version of this instance
-
void dynet::Dim
resize
(unsigned int i)¶ Change the number of dimensions.
- Parameters
i
: New number of dimensions
-
unsigned int dynet::Dim
rows
() const¶ Size of the first dimension.
- Return
- Size of the first dimension
-
unsigned int dynet::Dim
num_nonone_dims
() const¶ Number of non-one dimensions.
- Return
- Number of non-one dimensions
-
unsigned int dynet::Dim
cols
() const¶ Size of the second dimension (or 1 if only one dimension)
- Return
- Size of the second dimension (or 1 if only one dimension)
-
void dynet::Dim
set
(unsigned int i, unsigned int s)¶ Set specific dimension.
Set the value of a specific dimension to an arbitrary value
- Parameters
i
: Dimension indexs
: Dimension size
-
unsigned int dynet::Dim
operator[]
(unsigned int i) const¶ Access a specific dimension as you would access an array element.
- Return
- Size of dimension i
- Parameters
i
: Dimension index
-
unsigned int dynet::Dim
size
(unsigned int i) const¶ Size of dimension i.
- Return
- Size of dimension i
- Parameters
i
: Dimension index
-
void dynet::Dim
delete_dim
(unsigned int i)¶ Remove one of the dimensions.
- Parameters
i
: index of the dimension to be removed
-
void dynet::Dim
delete_dims
(std::vector<unsigned int> dims, bool reduce_batch)¶ Remove multiple dimensions.
- Parameters
dims
: dimensions to be removedreduce_batch
: reduce the batch dimension or not
-
void dynet::Dim
add_dim
(unsigned int n)¶ Insert a dimension to the end.
- Parameters
n
: the size of the new dimension
-
void dynet::Dim
insert_dim
(unsigned int i, unsigned int n)¶ Insert a dimension.
- Parameters
i
: the index to insert the new dimensionn
: the size of the new dimension
-
dynet::Dim
Operations¶
Operation Interface¶
The following functions define DyNet “Expressions,” which are used as an interface to the various functions that can be used to build DyNet computation graphs. Expressions for each specific function are listed below.
-
struct dynet
::
Expression
¶ - #include <expr.h>
Expressions are the building blocks of a DyNet computation graph.
Public Functions
-
dynet::Expression
Expression
(ComputationGraph *pg, VariableIndex i)¶ Base expression constructor.
Used when creating operations
- Parameters
pg
: Pointer to the computation graphi
: Variable index
-
const Tensor &dynet::Expression
value
() const¶ Get value of the expression.
Throws a runtime_error exception if no computation graph is available
- Return
- Value of the expression as a tensor
-
const Tensor &dynet::Expression
gradient
() const¶ Get gradient of the expression.
Throws a runtime_error exception if no computation graph is available
Make sure to call
backward
on a downstream expression before calling this. If the expression is a constant expression (meaning it’s not a function of a parameter), DyNet won’t compute its gradient for the sake of efficiency. You need to manually force the gradient computation by adding the argument
full=true
to backward
- Return
- Gradient of the expression as a tensor
-
const Dim &dynet::Expression
dim
() const¶ Get dimension of the expression.
Throws a runtime_error exception if no computation graph is available
- Return
- Dimension of the expression
-
dynet::Expression
Input Operations¶
These operations allow you to input something into the computation graph, either simple scalar/vector/matrix inputs from floats, or parameter inputs from a DyNet parameter object. They all require passing a computation graph as input so you know which graph is being used for this particular calculation.
-
Expression dynet
::
input
(ComputationGraph &g, real s, Device *device = dynet::default_device)¶ Scalar input.
Create an expression that represents the scalar value s
- Return
- An expression representing s
- Parameters
g
: Computation graphs
: Real numberdevice
: The device on which to place the input value, default_device by default
-
Expression dynet
::
input
(ComputationGraph &g, const real *ps, Device *device = dynet::default_device)¶ Modifiable scalar input.
Create an expression that represents the scalar value *ps. If *ps is changed and the computation graph recalculated, the next forward pass will reflect the new value.
- Return
- An expression representing *ps
- Parameters
g
: Computation graphps
: Real number pointerdevice
: The device on which to place the input value, default_device by default
-
Expression dynet
::
input
(ComputationGraph &g, const Dim &d, const std::vector<float> &data, Device *device = dynet::default_device)¶ Vector/matrix/tensor input.
Create an expression that represents a vector, matrix, or tensor input. The dimensions of the input are defined by
d
. So for example >input(g,{50},data)
: will result in a 50-length vector >input(g,{50,30},data)
: will result in a 50x30 matrix and so on, for an arbitrary number of dimensions. This function can also be used to import minibatched inputs. For example, if we have 10 examples in a minibatch, each with size 50x30, then we call >input(g,Dim({50,30},10),data)
The data vector “data” will contain the values used to fill the input, in column-major format. The length must equal the product of all dimensions in d.- Return
- An expression representing data
- Parameters
g
: Computation graphd
: Dimension of the input matrixdata
: A vector of data pointsdevice
: The device on which to place the input value, default_device by default
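A minimal sketch of these input variants (the 50x30 shape mirrors the example above; the fill values are placeholders):
#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <vector>

void feed_inputs(dynet::ComputationGraph &cg) {
  using namespace dynet;
  // A single 50x30 matrix; data is laid out in column-major order
  std::vector<float> data(50 * 30, 0.5f);
  Expression x = input(cg, {50, 30}, data);
  // A mini-batch of 10 such matrices: the batch size goes into the Dim
  std::vector<float> batch(50 * 30 * 10, 0.5f);
  Expression xb = input(cg, Dim({50, 30}, 10), batch);
}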
-
Expression dynet
::
input
(ComputationGraph &g, const Dim &d, const std::vector<float> *pdata, Device *device = dynet::default_device)¶ Updatable vector/matrix/tensor input.
Similarly to the input function that takes a vector reference, this inputs a vector, matrix, or tensor. Because we pass the pointer, the data can be updated.
- Return
- An expression representing *pdata
- Parameters
g
: Computation graphd
: Dimension of the input matrixpdata
: A pointer to an (updatable) vector of data pointsdevice
: The device on which to place the input value, default_device by default
-
Expression dynet
::
input
(ComputationGraph &g, const Dim &d, const std::vector<unsigned int> &ids, const std::vector<float> &data, float defdata = 0.f, Device *device = dynet::default_device)¶ Sparse vector input.
This operation takes input as a sparse matrix of index/value pairs. It is exactly the same as the standard input via vector reference, but sets all non-specified values to “defdata” and resets all others to the appropriate input values.
- Return
- An expression representing data
- Parameters
g
: Computation graphd
: Dimension of the input matrixids
: The indexes of the data points to updatedata
: The data points corresponding to each indexdefdata
: The default data with which to set the unspecified data pointsdevice
: The device on which to place the input value, default_device by default
-
Expression dynet
::
parameter
(ComputationGraph &g, Parameter p)¶ Load parameter.
Load parameters into the computation graph.
- Return
- An expression representing p
- Parameters
g
: Computation graphp
: Parameter object to load
-
Expression dynet
::
parameter
(ComputationGraph &g, LookupParameter lp)¶ Load lookup parameter.
Load a full tensor of lookup parameters into the computation graph. Normally lookup parameters are accessed by using the lookup() function to grab a single element. However, in some cases we’ll want to access all of the parameters in the entire set of lookup parameters, and in that case you can use this function. The first dimensions in the returned tensor will be equivalent to the dimensions we would get by calling the lookup() function, and the size of the final dimension will be equal to the size of the vocabulary.
- Return
- An expression representing lp
- Parameters
g
: Computation graphlp
: LookupParameter object to load
-
Expression dynet
::
const_parameter
(ComputationGraph &g, Parameter p)¶ Load constant parameters.
Load parameters into the computation graph, but prevent them from being updated when performing parameter update.
- Return
- An expression representing the constant p
- Parameters
g
: Computation graphp
: Parameter object to load
-
Expression dynet
::
const_parameter
(ComputationGraph &g, LookupParameter lp)¶ Load constant lookup parameters.
Load lookup parameters into the computation graph, but prevent them from being updated when performing parameter update.
- Return
- An expression representing the constant lp
- Parameters
g
: Computation graphlp
: LookupParameter object to load
-
Expression dynet
::
lookup
(ComputationGraph &g, LookupParameter p, unsigned index)¶ Look up parameter.
Look up parameters according to an index, and load them into the computation graph.
- Return
- An expression representing p[index]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadindex
: Index of the parameters within p
-
Expression dynet
::
lookup
(ComputationGraph &g, LookupParameter p, const unsigned *pindex)¶ Look up parameters with modifiable index.
Look up parameters according to the *pindex, and load them into the computation graph. When *pindex changes, on the next computation of forward() the values will change.
- Return
- An expression representing p[*pindex]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadpindex
: Pointer index of the parameters within p
-
Expression dynet
::
const_lookup
(ComputationGraph &g, LookupParameter p, unsigned index)¶ Look up parameter.
Look up parameters according to an index, and load them into the computation graph. Do not perform gradient update on the parameters.
- Return
- A constant expression representing p[index]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadindex
: Index of the parameters within p
-
Expression dynet
::
const_lookup
(ComputationGraph &g, LookupParameter p, const unsigned *pindex)¶ Constant lookup parameters with modifiable index.
Look up parameters according to the *pindex, and load them into the computation graph. When *pindex changes, on the next computation of forward() the values will change. However, gradient updates will not be performed.
- Return
- A constant expression representing p[*pindex]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadpindex
: Pointer index of the parameters within p
-
Expression dynet
::
lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> &indices)¶ Look up parameters.
The mini-batched version of lookup. The resulting expression will be a mini-batch of parameters, where the “i”th element of the batch corresponds to the parameters at the position specified by the “i”th element of “indices”
- Return
- An expression with the “i”th batch element representing p[indices[i]]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadindices
: Index of the parameters at each position in the batch
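A short sketch of single and mini-batched lookups; the lookup table is assumed to exist and the word IDs are arbitrary:
#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <vector>

void embed_words(dynet::ComputationGraph &cg, dynet::LookupParameter embeddings) {
  using namespace dynet;
  // One embedding, for word ID 42
  Expression e = lookup(cg, embeddings, 42u);
  // Mini-batched lookup: one batch element per index in ids
  std::vector<unsigned> ids = {4, 8, 15};
  Expression eb = lookup(cg, embeddings, ids);
}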
-
Expression dynet
::
lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> *pindices)¶ Look up parameters.
The mini-batched version of lookup with modifiable parameter indices.
- Return
- An expression with the “i”th batch element representing p[*pindices[i]]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadpindices
: Pointer to lookup indices
-
Expression dynet
::
const_lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> &indices)¶ Look up parameters.
Mini-batched lookup that will not update the parameters.
- Return
- A constant expression with the “i”th batch element representing p[indices[i]]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadindices
: Lookup indices
-
Expression dynet
::
const_lookup
(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> *pindices)¶ Look up parameters.
Mini-batched lookup that will not update the parameters, with modifiable indices.
- Return
- A constant expression with the “i”th batch element representing p[*pindices[i]]
- Parameters
g
: Computation graphp
: LookupParameter object from which to loadpindices
: Lookup index pointers.
-
Expression dynet
::
zeros
(ComputationGraph &g, const Dim &d)¶ Create an input full of zeros.
Create an input full of zeros, sized according to dimensions d.
- Return
- A
d
dimensioned zero tensor - Parameters
g
: Computation graphd
: The dimensions of the input
-
Expression dynet
::
ones
(ComputationGraph &g, const Dim &d)¶ Create an input full of ones.
Create an input full of ones, sized according to dimensions d.
- Return
- A
d
dimensioned tensor of ones - Parameters
g
: Computation graphd
: The dimensions of the input
-
Expression dynet
::
constant
(ComputationGraph &g, const Dim &d, float val)¶ Create an input with one constant value.
Create an input full of
val
, sized according to dimensions d.- Return
- A
d
dimensioned tensor filled with valueval
- Parameters
g
: Computation graphd
: The dimensions of the inputval
: The value of the input
-
Expression dynet
::
random_normal
(ComputationGraph &g, const Dim &d)¶ Create a random normal vector.
Create a vector distributed according to normal distribution with mean 0, variance 1.
- Return
- A “d” dimensioned normally distributed vector
- Parameters
g
: Computation graphd
: The dimensions of the input
-
Expression dynet
::
random_bernoulli
(ComputationGraph &g, const Dim &d, real p, real scale = 1.0f)¶ Create a random bernoulli vector.
Create a vector distributed according to bernoulli distribution with parameter p.
- Return
- A “d” dimensioned bernoulli distributed vector
- Parameters
g
: Computation graphd
: The dimensions of the inputp
: The bernoulli p parameterscale
: A scaling factor for the output (“active” elements will receive this value)
-
Expression dynet
::
random_uniform
(ComputationGraph &g, const Dim &d, real left, real right)¶ Create a random uniform vector.
Create a vector distributed according to uniform distribution with boundaries left and right.
- Return
- A “d” dimensioned uniform distributed vector
- Parameters
g
: Computation graphd
: The dimensions of the inputleft
: The left boundaryright
: The right boundary
-
Expression dynet
::
random_gumbel
(ComputationGraph &g, const Dim &d, real mu = 0.0, real beta = 1.0)¶ Create a random Gumbel sampled vector.
Create a vector distributed according to a Gumbel distribution with the specified parameters. (Currently only the defaults of mu=0.0 and beta=1.0 are supported.)
- Return
- A “d” dimensioned Gumbel distributed vector
- Parameters
g
: Computation graphd
: The dimensions of the inputmu
: The mu parameterbeta
: The beta parameter
Arithmetic Operations¶
These operations perform basic arithmetic over values in the graph.
-
Expression dynet
::
operator-
(const Expression &x)¶ Negation.
Negate the passed argument.
- Return
- The negation of x
- Parameters
x
: An input expression
-
Expression dynet
::
operator+
(const Expression &x, const Expression &y)¶ Expression addition.
Add two expressions of the same dimensions.
- Return
- The sum of x and y
- Parameters
x
: The first inputy
: The second input
-
Expression dynet
::
operator+
(const Expression &x, real y)¶ Scalar addition.
Add a scalar to an expression
- Return
- An expression equal to x, with every component increased by y
- Parameters
x
: The expressiony
: The scalar
-
Expression dynet
::
operator+
(real x, const Expression &y)¶ Scalar addition.
Add a scalar to an expression
- Return
- An expression equal to y, with every component increased by x
- Parameters
x
: The scalary
: The expression
-
Expression dynet
::
operator-
(const Expression &x, const Expression &y)¶ Expression subtraction.
Subtract one expression from another.
- Return
- An expression where the ith element is x_i minus y_i
- Parameters
x
: The expression from which to subtracty
: The expression to subtract
-
Expression dynet
::
operator-
(real x, const Expression &y)¶ Scalar subtraction.
Subtract an expression from a scalar
- Return
- An expression where the ith element is x_i minus y
- Parameters
x
: The scalar from which to subtracty
: The expression to subtract
-
Expression dynet
::
operator-
(const Expression &x, real y)¶ Scalar subtraction.
Subtract a scalar from an expression
- Return
- An expression where the ith element is x_i minus y
- Parameters
x
: The expression from which to subtracty
: The scalar to subtract
-
Expression dynet
::
operator*
(const Expression &x, const Expression &y)¶ Matrix multiplication.
Multiply two matrices together. Like standard matrix multiplication, the second dimension of x and the first dimension of y must match.
- Return
- An expression x times y
- Parameters
x
: The left-hand matrixy
: The right-hand matrix
-
Expression dynet
::
operator*
(const Expression &x, float y)¶ Matrix-scalar multiplication.
Multiply an expression component-wise by a scalar.
- Return
- An expression where the ith element is x_i times y
- Parameters
x
: The matrixy
: The scalar
-
Expression dynet
::
operator*
(float y, const Expression &x)¶ Matrix-scalar multiplication.
Multiply an expression component-wise by a scalar.
- Return
- An expression where the ith element is x_i times y
- Parameters
x
: The scalary
: The matrix
-
Expression dynet
::
operator/
(const Expression &x, float y)¶ Matrix-scalar division.
Divide an expression component-wise by a scalar.
- Return
- An expression where the ith element is x_i divided by y
- Parameters
x
: The matrixy
: The scalar
-
Expression dynet
::
affine_transform
(const std::initializer_list<Expression> &xs)¶ Affine transform.
This performs an affine transform over an arbitrary (odd) number of expressions held in the input initializer list xs. The first expression is the “bias,” which is added to the expression as-is. The remaining expressions are multiplied together in pairs, then added. A very common usage case is the calculation of the score for a neural network layer (e.g. b + Wz) where b is the bias, W is the weight matrix, and z is the input. In this case xs[0] = b, xs[1] = W, and xs[2] = z.
- Return
- An expression equal to: xs[0] + xs[1]*xs[2] + xs[3]*xs[4] + …
- Parameters
xs
: An initializer list containing an odd number of expressions
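A minimal sketch of the b + W*z case described above; the parameter shapes (an 8x4 W and a length-8 b) and the input values are assumptions:
#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <vector>

// p_W and p_b are assumed to have been added to a ParameterCollection
// with shapes {8, 4} and {8} respectively.
dynet::Expression linear_score(dynet::ComputationGraph &cg,
                               dynet::Parameter p_W, dynet::Parameter p_b) {
  using namespace dynet;
  std::vector<float> z_vals = {0.1f, -0.2f, 0.3f, 0.4f};
  Expression W = parameter(cg, p_W);
  Expression b = parameter(cg, p_b);
  Expression z = input(cg, {4}, z_vals);
  return affine_transform({b, W, z});   // equals b + W * z
}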
-
Expression dynet
::
sum
(const std::initializer_list<Expression> &xs)¶ Sum.
This performs an elementwise sum over all the expressions in xs
- Return
- An expression where the ith element is equal to xs[0][i] + xs[1][i] + …
- Parameters
xs
: An initializer list containing expressions
-
Expression dynet
::
sum_elems
(const Expression &x)¶ Sum all elements.
Sum all the elements in an expression.
- Return
- The sum of all of its elements
- Parameters
x
: The input expression
-
Expression dynet
::
moment_elems
(const Expression &x, unsigned r)¶ Compute moment over all elements.
Compute the moment of order \(r\), \(\frac 1 n\sum_{i=1}^nx_i^r\) over all the elements in each batch of the expression
- Return
- A scalar expression (with a potential batch dimension)
- Parameters
x
: The input mini-batched expressionr
: Order of the moment
-
Expression dynet
::
mean_elems
(const Expression &x)¶ Compute mean over all elements.
Computes \(\frac 1 n\sum_{i=1}^nx_i\) over all the elements in each batch of the expression
- Return
- A scalar expression (with a potential batch dimension)
- Parameters
x
: The input mini-batched expression
-
Expression dynet
::
std_elems
(const Expression &x)¶ Compute standard deviation over all elements.
Computes \(\sqrt{\frac 1 n\sum_{i=1}^n(x_i -\mu)^2}\) where \(\mu=\frac 1 n\sum_{i=1}^nx_i\) over all the elements in each batch of the expression
- Return
- A scalar expression (with a potential batch dimension)
- Parameters
x
: The input mini-batched expression
-
Expression dynet
::
sum_batches
(const Expression &x)¶ Sum over minibatches.
Sum an expression that consists of multiple minibatches into one of equal dimension but with only a single minibatch. This is useful for summing loss functions at the end of minibatch training.
- Return
- An expression with a single batch
- Parameters
x
: The input mini-batched expression
-
Expression dynet
::
moment_batches
(const Expression &x, unsigned r)¶ Compute moment over minibatches.
Compute the moment of order \(r\), \(\frac 1 n\sum_{i=1}^nx_i^r\) along the batch dimension
- Return
- An expression with a single batch
- Parameters
x
: The input mini-batched expressionr
: Order of the moment
-
Expression dynet
::
mean_batches
(const Expression &x)¶ Compute mean over minibatches.
Computes \(\frac 1 n\sum_{i=1}^nx_i\) along the batch dimension
- Return
- An expression with a single batch
- Parameters
x
: The input mini-batched expression
-
Expression dynet
::
std_batches
(const Expression &x)¶ Compute standard deviation over minibatches.
Computes \(\sqrt{\frac 1 n\sum_{i=1}^n(x_i -\mu)^2}\) where \(\mu=\frac 1 n\sum_{i=1}^nx_i\) along the batch dimension
- Return
- A scalar expression (with a potential batch dimension)
- Parameters
x
: The input mini-batched expression
-
Expression dynet
::
sum_dim
(const Expression &x, const std::vector<unsigned> &dims, bool b = false)¶ Compute sum along a specific dimension or dimensions.
Compute the sum along a specific dimension or dimensions
- Return
- An expression with |d| less dimensions and possibly dropped batch dimension
- Parameters
x
: The input mini-batched expressiond
: Dimensions along which to reduceb
: Whether to include batch dimension (default: false)
-
Expression dynet
::
moment_dim
(const Expression &x, const std::vector<unsigned> &dims, unsigned r, bool b = false, unsigned n = 0)¶ Compute moment along a specific dimension.
Compute the moment of order \(r\), \(\frac 1 n\sum_{i=1}^nx_i^r\) along a specific dimension
- Return
- An expression with |d| less dimensions and possibly dropped batch dimension
- Parameters
x
: The input mini-batched expressiond
: Dimensions along which to reducer
: Order of the momentb
: Whether to include batch dimension (default: false)n
: If > 0, overwrite the n in the equation by this value, useful for masking (default: 0)
-
Expression dynet
::
mean_dim
(const Expression &x, const std::vector<unsigned> &dims, bool b = false, unsigned n = 0)¶ Compute mean along a specific dimension.
Computes \(\frac 1 n\sum_{i=1}^nx_i\) along a specific dimension
- Return
- An expression with |d| less dimensions and possibly dropped batch dimension
- Parameters
x
: The input mini-batched expressiond
: Dimensions along which to reduceb
: Whether to include batch dimension (default: false)n
: If > 0, overwrite the n in the equation by this value, useful for masking (default: 0)
-
Expression dynet
::
std_dim
(const Expression &x, const std::vector<unsigned> &dims, bool b = false, unsigned n = 0)¶ Compute standard deviation along an arbitrary dimension.
Computes \(\sqrt{\frac 1 n\sum_{i=1}^n(x_i -\mu)^2}\) where \(\mu=\frac 1 n\sum_{i=1}^nx_i\) along an arbitrary dimension
- Return
- An expression with |d| less dimensions and possibly dropped batch dimension
- Parameters
x
: The input mini-batched expressiond
: Dimensions along which to reduceb
: Whether to include batch dimension (default: false)n
: If > 0, overwrite the n in the equation by this value, useful for masking (default: 0)
-
Expression dynet
::
average
(const std::initializer_list<Expression> &xs)¶ Average.
This performs an elementwise average over all the expressions in xs
- Return
- An expression where the ith element is equal to (xs[0][i] + xs[1][i] + …)/|xs|
- Parameters
xs
: An initializer list containing expressions
-
Expression dynet
::
sqrt
(const Expression &x)¶ Square root.
Elementwise square root.
- Return
- An expression where the ith element is equal to \(\sqrt{x_i}\)
- Parameters
x
: The input expression
-
Expression dynet
::
abs
(const Expression &x)¶ Absolute value.
Elementwise absolute value.
- Return
- An expression where the ith element is equal to \(\vert x_i\vert\)
- Parameters
x
: The input expression
-
Expression dynet
::
erf
(const Expression &x)¶ Gaussian error function.
Elementwise calculation of the Gaussian error function
- Return
- An expression where the ith element is equal to erf(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
asin
(const Expression &x)¶ Inverse sine.
Elementwise calculation of the inverse sine
- Return
- An expression where the ith element is equal to asin(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
acos
(const Expression &x)¶ Inverse cosine.
Elementwise calculation of the inverse cosine
- Return
- An expression where the ith element is equal to acos(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
atan
(const Expression &x)¶ Inverse tangent.
Elementwise calculation of the inverse tangent
- Return
- An expression where the ith element is equal to atan(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
sin
(const Expression &x)¶ Sine.
Elementwise calculation of the sine
- Return
- An expression where the ith element is equal to sin(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
cos
(const Expression &x)¶ Cosine.
Elementwise calculation of the cosine
- Return
- An expression where the ith element is equal to cos(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
tan
(const Expression &x)¶ Tangent.
Elementwise calculation of the tangent
- Return
- An expression where the ith element is equal to tan(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
sinh
(const Expression &x)¶ Hyperbolic sine.
Elementwise calculation of the hyperbolic sine
- Return
- An expression where the ith element is equal to sinh(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
cosh
(const Expression &x)¶ Hyperbolic cosine.
Elementwise calculation of the hyperbolic cosine
- Return
- An expression where the ith element is equal to cosh(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
tanh
(const Expression &x)¶ Hyperbolic tangent.
Elementwise calculation of the hyperbolic tangent
- Return
- An expression where the ith element is equal to tanh(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
asinh
(const Expression &x)¶ Inverse hyperbolic sine.
Elementwise calculation of the inverse hyperbolic sine
- Return
- An expression where the ith element is equal to asinh(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
acosh
(const Expression &x)¶ Inverse hyperbolic cosine.
Elementwise calculation of the inverse hyperbolic cosine
- Return
- An expression where the ith element is equal to acosh(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
atanh
(const Expression &x)¶ Inverse hyperbolic tangent.
Elementwise calculation of the inverse hyperbolic tangent
- Return
- An expression where the ith element is equal to atanh(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
exp
(const Expression &x)¶ Natural exponent.
Calculate elementwise y_i = e^{x_i}
- Return
- An expression where the ith element is equal to e^{x_i}
- Parameters
x
: The input expression
-
Expression dynet
::
square
(const Expression &x)¶ Square.
Calculate elementwise y_i = x_i^2
- Return
- An expression where the ith element is equal to x_i^2
- Parameters
x
: The input expression
-
Expression dynet
::
cube
(const Expression &x)¶ Cube.
Calculate elementwise y_i = x_i^3
- Return
- An expression where the ith element is equal to x_i^3
- Parameters
x
: The input expression
-
Expression dynet
::
lgamma
(const Expression &x)¶ Log gamma.
Calculate elementwise y_i = ln(gamma(x_i))
- Return
- An expression where the ith element is equal to ln(gamma(x_i))
- Parameters
x
: The input expression
-
Expression dynet
::
log
(const Expression &x)¶ Logarithm.
Calculate the elementwise natural logarithm y_i = ln(x_i)
- Return
- An expression where the ith element is equal to ln(x_i)
- Parameters
x
: The input expression
-
Expression dynet
::
logistic
(const Expression &x)¶ Logistic sigmoid function.
Calculate elementwise y_i = 1/(1+e^{-x_i})
- Return
- An expression where the ith element is equal to y_i = 1/(1+e^{-x_i})
- Parameters
x
: The input expression
-
Expression dynet
::
rectify
(const Expression &x)¶ Rectifier.
Calculate elementwise the rectifier (ReLU) function y_i = max(x_i,0)
- Return
- An expression where the ith element is equal to max(x_i,0)
- Parameters
x
: The input expression
-
Expression dynet
::
elu
(const Expression &x, float alpha = 1.f)¶ Exponential Linear Unit.
Calculate elementwise the function
\( y_i = \left\{\begin{array}{lr} x_i, & \text{if } x>0\\ \alpha\times(e^{x_i} - 1), & \text{if }x\leqslant 0\\ \end{array}\right. \)
Reference: Clevert et al., 2015
- Return
- An expression where the ith element is equal to \(\text{ELU}(x_i, \alpha)\)
- Parameters
x
: The input expression
-
Expression dynet
::
selu
(const Expression &x)¶ Scaled Exponential Linear Unit (SELU)
Calculate elementwise the function
\( y_i = \lambda\times\left\{\begin{array}{lr} x_i, & \text{if } x>0\\ \alpha\times(e^{x_i} - 1), & \text{if }x\leqslant 0\\ \end{array}\right. \)
With \( \begin{split} \lambda &=\texttt{1.0507009873554804934193349852946}\\ \alpha &=\texttt{1.6732632423543772848170429916717}\\ \end{split} \)
Reference: Klambauer et al., 2017
- Return
- An expression where the ith element is equal to \(\text{SELU}(x_i)\)
- Parameters
x
: The input expression
-
Expression dynet
::
silu
(const Expression &x, float beta = 1.f)¶ SILU / SiL / Swish.
Calculate elementwise y_i = x_i / (1 + e^{-beta * x_i})
Reference: Hendrycks and Gimpel, 2016, Elfwing et al, 2017, and Ramachandran et al., 2017
- Return
- An expression where the ith element is equal to y_i = x_i / (1 + e^{-beta * x_i})
- Parameters
x
: The input expression
-
Expression dynet
::
softsign
(const Expression &x)¶ Soft Sign.
Calculate elementwise the softsign function y_i = x_i/(1+|x_i|)
- Return
- An expression where the ith element is equal to x_i/(1+|x_i|)
- Parameters
x
: The input expression
-
Expression dynet
::
pow
(const Expression &x, const Expression &y)¶ Power function.
Calculate an output where the ith element is equal to x_i^y_i
- Return
- An expression where the ith element is equal to x_i^y_i
- Parameters
x
: The input expressiony
: The exponent expression
-
Expression dynet
::
min
(const Expression &x, const Expression &y)¶ Minimum.
Calculate an output where the ith element is min(x_i,y_i)
- Return
- An expression where the ith element is equal to min(x_i,y_i)
- Parameters
x
: The first input expressiony
: The second input expression
-
Expression dynet
::
max
(const Expression &x, const Expression &y)¶ Maximum.
Calculate an output where the ith element is max(x_i,y_i)
- Return
- An expression where the ith element is equal to max(x_i,y_i)
- Parameters
x
: The first input expressiony
: The second input expression
-
Expression dynet
::
max
(const std::initializer_list<Expression> &xs)¶ Max.
This performs an elementwise max over all the expressions in xs
- Return
- An expression where the ith element is equal to max(xs[0][i], xs[1][i], …)
- Parameters
xs
: An initializer list containing expressions
-
Expression dynet
::
dot_product
(const Expression &x, const Expression &y)¶ Dot Product.
Calculate the dot product sum_i x_i*y_i
- Return
- An expression equal to the dot product
- Parameters
x
: The input expressiony
: The input expression
-
Expression dynet
::
circ_conv
(const Expression &u, const Expression &v)¶ Circular convolution.
Calculate the circular convolution
- Return
- An expression equal to the circular convolution
- Parameters
x
: The input expressiony
: The input expression
-
Expression dynet
::
circ_corr
(const Expression &u, const Expression &v)¶ Circular correlation.
Calculate the circular correlation
- Return
- An expression equal to the circular correlation
- Parameters
x
: The input expressiony
: The input expression
-
Expression dynet
::
cmult
(const Expression &x, const Expression &y)¶ Componentwise multiply.
Multiply two expressions component-wise, broadcasting dimensions if necessary as follows:
- When number of dimensions differ, we add dimensions of size 1 to make the number of dimensions match
- Now, every dimension is required to have matching size, or one of the dimensions must equal 1 (in which case it will be broadcasted)
- In the same way, the batch dimension must match, or equal 1 in which case it will be broadcasted
- The resulting tensor’s dimensionality is thus determined as the max of both inputs at every position
- Return
- An expression where the ith element is equal to x_i*y_i
- Parameters
x
: The first input expressiony
: The second input expression
-
Expression dynet
::
cdiv
(const Expression &x, const Expression &y)¶ Componentwise division.
Divide an expression component-wise by another, broadcasting dimensions (currently only of the second expression!) if necessary as follows:
- When number of dimensions differ, we add dimensions of size 1 to make the number of dimensions match
- Now, every dimension is required to have matching size, or the dim size of the right expression must equal 1 (in which case it will be broadcasted)
- In the same way, the batch sizes must match, or the batch size of the right expression must equal 1 in which case it will be broadcasted
- The resulting tensor’s dimensionality is thus determined as the max of both inputs at every position
- Return
- An expression where the ith element is equal to x_i/y_i
- Parameters
x
: The first input expressiony
: The second input expression
-
Expression dynet
::
colwise_add
(const Expression &x, const Expression &bias)¶ Columnwise addition.
Add vector “bias” to each column of matrix “x”
- Return
- An expression where bias is added to each column of x
- Parameters
x
: An MxN matrixbias
: A length M vector
Probability/Loss Operations¶
These operations are used for calculating probabilities, or calculating loss functions for use in training.
-
Expression dynet
::
softmax
(const Expression &x, unsigned d = 0)¶ Softmax.
The softmax function normalizes each column to ensure that all values are between 0 and 1 and add to one by applying e^{x[i]}/sum_j e^{x[j]}.
- Return
- A vector or matrix after calculating the softmax
- Parameters
x
: A vector or matrixd
: dimension to normalize over (default: 0)
-
Expression dynet
::
log_softmax
(const Expression &x)¶ Log softmax.
The log softmax function normalizes each column to ensure that all values are between 0 and 1 and add to one by applying e^{x[i]}/sum_j e^{x[j]}, then takes the log
- Return
- A vector or matrix after calculating the log softmax
- Parameters
x
: A vector or matrix
-
Expression dynet
::
log_softmax
(const Expression &x, const std::vector<unsigned> &restriction)¶ Restricted log softmax.
The log softmax function calculated over only a subset of the vector elements. The elements to be included are set by the
restriction
variable. All elements not included inrestriction
are set to negative infinity.- Return
- A vector with the log softmax over the specified elements
- Parameters
x
: A vector over which to calculate the softmaxrestriction
: The elements over which to calculate the softmax
-
Expression dynet
::
logsumexp_dim
(const Expression &x, unsigned d)¶ Log, sum, exp by dimension.
The “logsumexp” function calculated over a particular dimension \(ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.
- Return
- The result.
- Parameters
x
: Expression with respect to which to calculate the logsumexp.d
: The dimension along which to do the logsumexp.
-
Expression dynet
::
logsumexp
(const std::initializer_list<Expression> &xs)¶ Log, sum, exp.
The elementwise “logsumexp” function that calculates \(ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.
- Return
- The result.
- Parameters
xs
: Expressions with respect to which to calculate the logsumexp.
-
Expression dynet
::
pickneglogsoftmax
(const Expression &x, unsigned v)¶ Negative softmax log likelihood.
This function takes in a vector of scores
x
, and performs a log softmax, takes the negative, and selects the likelihood corresponding to the elementv
. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.- Return
- The negative log likelihood of element
v
after taking the softmax - Parameters
x
: A vector of scoresv
: The element with which to calculate the loss
-
Expression dynet
::
pickneglogsoftmax
(const Expression &x, const unsigned *pv)¶ Modifiable negative softmax log likelihood.
This function calculates the negative log likelihood after the softmax with respect to index
*pv
. This computes the same value as the previous function that passes the indexv
by value, but instead passes by pointer so the value*pv
can be modified without re-constructing the computation graph. This can be used in situations where we want to create a computation graph once, then feed it different data points.- Return
- The negative log likelihood of element
*pv
after taking the softmax - Parameters
x
: A vector of scorespv
: A pointer to the index of the correct element
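A hedged sketch of the “build the graph once, feed different data” pattern this overload enables, reusing the scores and cg from the previous sketch (the labels are arbitrary):
unsigned gold = 0;                                   // modified in place between evaluations
Expression loss = pickneglogsoftmax(scores, &gold);  // the graph keeps the pointer
for (unsigned label : {0u, 1u, 2u}) {
  gold = label;                                      // no graph rebuild required
  std::cout << as_scalar(cg.forward(loss)) << std::endl;
}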
-
Expression dynet
::
pickneglogsoftmax
(const Expression &x, const std::vector<unsigned> &v)¶ Batched negative softmax log likelihood.
This function is similar to standard pickneglogsoftmax, but calculates loss with respect to multiple batch elements. The input will be a mini-batch of score vectors where the number of batch elements is equal to the number of indices in
v
.- Return
- The negative log likelihoods over all the batch elements
- Parameters
x
: An expression with vectors of scores over N batch elementsv
: A size-N vector indicating the index with respect to all the batch elements
-
Expression dynet
::
pickneglogsoftmax
(const Expression &x, const std::vector<unsigned> *pv)¶ Modifiable batched negative softmax log likelihood.
This function is a combination of modifiable pickneglogsoftmax and batched pickneglogsoftmax:
pv
can be modified without re-creating the computation graph.- Return
- The negative log likelihoods over all the batch elements
- Parameters
x
: An expression with vectors of scores over N batch elementspv
: A pointer to the indexes
-
Expression dynet
::
hinge
(const Expression &x, unsigned index, float m = 1.0)¶ Hinge loss.
This expression calculates the hinge loss, formally expressed as: \( \text{hinge}(x,index,m) = \sum_{i \ne index} \max(0, m-x[index]+x[i]). \)
- Return
- The hinge loss of candidate
index
with respect to marginm
- Parameters
x
: A vector of scoresindex
: The index of the correct candidatem
: The margin
-
Expression dynet
::
hinge
(const Expression &x, const unsigned *pindex, float m = 1.0)¶ Modifiable hinge loss.
This function calculates the hinge loss with respect to index
*pindex
. This computes the same value as the previous function that passes the indexindex
by value, but instead passes by pointer so the value*pindex
can be modified without re-constructing the computation graph. This can be used in situations where we want to create a computation graph once, then feed it different data points.- Return
- The hinge loss of candidate
*pindex
with respect to marginm
- Parameters
x
: A vector of scorespindex
: A pointer to the index of the correct candidatem
: The margin
-
Expression dynet
::
hinge
(const Expression &x, const std::vector<unsigned> &indices, float m = 1.0)¶ Batched hinge loss.
The same as hinge loss, but for the case where
x
is a mini-batched tensor withindices.size()
batch elements, andindices
is a vector indicating the index of each of the correct elements for these elements.- Return
- The hinge loss of each mini-batch
- Parameters
x
: A mini-batch of vectors withindices.size()
batch elementsindices
: The indices of the correct candidates for each batch elementm
: The margin
-
Expression dynet
::
hinge
(const Expression &x, const std::vector<unsigned> *pindices, float m = 1.0)¶ Batched modifiable hinge loss.
A combination of the previous batched and modifiable hinge loss functions, where vector
*pindices
can be modified.- Return
- The hinge loss of each mini-batch
- Parameters
x
: A mini-batch of vectors withindices.size()
batch elementspindices
: Pointer to the indices of the correct candidates for each batch elementm
: The margin
-
Expression dynet
::
hinge_dim
(const Expression &x, const std::vector<unsigned> &indices, unsigned d = 0, float m = 1.0)¶ Dimensionwise hinge loss.
This expression calculates the hinge loss over a particular dimension
d
.- Return
- A vector of hinge losses for each index in
indices
. - Parameters
x
: A matrix of scoresindices
: The indices of the correct candidate (equal in length to the dimension not specified by “d”)d
: The dimension over which to calculate the loss (0 or 1)m
: The margin
-
Expression dynet
::
hinge_dim
(const Expression &x, const std::vector<unsigned> *pindex, unsigned d = 0, float m = 1.0)¶ Modifiable dimensionwise hinge loss.
This function calculates the modifiable version of dimensionwise hinge loss.
- Return
- A vector of hinge losses for each index in
indices
. - Parameters
x
: A vector of scorespindex
: A pointer to the index of the correct candidated
: The dimension over which to calculate the loss (0 or 1)m
: The margin
-
Expression dynet
::
hinge_dim
(const Expression &x, const std::vector<std::vector<unsigned>> &indices, unsigned d = 0, float m = 1.0)¶ Batched dimensionwise hinge loss.
The same as dimensionwise hinge loss, but for the case where
x
is a mini-batched tensor withindices.size()
batch elements.- Return
- A vector of hinge losses for each mini-batch
- Parameters
x
: A mini-batch of vectors withindices.size()
batch elementsindices
: The indices of the correct candidates for each batch elementd
: The dimension over which to calculate the loss (0 or 1)m
: The margin
-
Expression dynet
::
hinge_dim
(const Expression &x, const std::vector<std::vector<unsigned>> *pindices, unsigned d = 0, float m = 1.0)¶ Batched modifiable dimensionwise hinge loss.
A combination of the previous batched and modifiable hinge loss functions, where vector
*pindices
can be modified.- Return
- The hinge loss of each mini-batch
- Parameters
x
: A mini-batch of vectors withindices.size()
batch elementspindices
: Pointer to the indices of the correct candidates for each batch elementd
: The dimension over which to calculate the loss (0 or 1)m
: The margin
-
Expression dynet
::
sparsemax
(const Expression &x)¶ Sparsemax.
The sparsemax function (Martins et al. 2016), which is similar to softmax, but induces sparse solutions where most of the vector elements are zero. Note: This function is not yet implemented on GPU.
- Return
- The sparsemax of the scores
- Parameters
x
: A vector of scores
-
Expression dynet
::
sparsemax_loss
(const Expression &x, const std::vector<unsigned> &target_support)¶ Sparsemax loss.
The sparsemax loss function (Martins et al. 2016), which is similar to softmax loss, but induces sparse solutions where most of the vector elements are zero. It has a gradient similar to the sparsemax function and thus is useful for optimizing when the sparsemax will be used at test time. Note: This function is not yet implemented on GPU.
- Return
- The sparsemax loss of the labels
- Parameters
x
: A vector of scorestarget_support
: The target correct labels.
-
Expression dynet
::
sparsemax_loss
(const Expression &x, const std::vector<unsigned> *ptarget_support)¶ Modifiable sparsemax loss.
Similar to the sparsemax loss, but with ptarget_support being a pointer to a vector, allowing it to be modified without re-creating the computation graph. Note: This function is not yet implemented on GPU.
- Return
- The sparsemax loss of the labels
- Parameters
x
: A vector of scoresptarget_support
: A pointer to the target correct labels.
-
Expression dynet
::
constrained_softmax
(const Expression &x, const Expression &y)¶ Constrained softmax.
The constrained softmax function. Note: This function is not yet implemented on GPU.
- Return
- The constrained softmax of the scores.
- Parameters
x
: A vector of scoresy
: A vector of upper bound constraints on probabilities
-
Expression dynet
::
squared_norm
(const Expression &x)¶ Squared norm.
The squared L2 norm of the values of x: \(\sum_i x_i^2\).
- Return
- The squared L2 norm
- Parameters
x
: A vector of values
-
Expression dynet
::
l2_norm
(const Expression &x)¶ L2 norm.
The L2 norm of the values of x: \(\sqrt{\sum_i x_i^2}\).
- Return
- The L2 norm
- Parameters
x
: A vector of values
-
Expression dynet
::
squared_distance
(const Expression &x, const Expression &y)¶ Squared distance.
The squared distance between values of
x
andy
: \(\sum_i (x_i-y_i)^2\).- Return
- The squared distance
- Parameters
x
: A vector of valuesy
: Another vector of values
-
Expression dynet
::
l1_distance
(const Expression &x, const Expression &y)¶ L1 distance.
The L1 distance between values of
x
andy
: \(\sum_i |x_i-y_i|\).- Return
- The L1 distance
- Parameters
x
: A vector of valuesy
: Another vector of values
-
Expression dynet
::
huber_distance
(const Expression &x, const Expression &y, float c = 1.345f)¶ Huber distance.
The huber distance between the values of x and y, parameterized by c: \(\sum_i L_c(x_i, y_i)\), where
\( L_c(x, y) = \begin{cases} \frac{1}{2}(y - x)^2 & \textrm{for } |y - x| \le c, \\ c\, |y - x| - \frac{1}{2}c^2 & \textrm{otherwise.} \end{cases} \)
- Return
- The huber distance
- Parameters
x
: A vector of valuesy
: Another vector of valuesc
: The parameter of the huber distance parameterizing the cutoff
-
Expression dynet
::
binary_log_loss
(const Expression &x, const Expression &y)¶ Binary log loss.
The log loss of a binary decision according to the sigmoid function: \(- \sum_i (y_i \ln(x_i) + (1-y_i) \ln(1-x_i)) \)
- Return
- The log loss of the sigmoid function
- Parameters
x
: A vector of valuesy
: A vector of true answers
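A minimal sketch of computing this loss on dummy predicted probabilities and 0/1 targets (values are illustrative; the predictions would normally come from a sigmoid layer):
ComputationGraph cg;
std::vector<float> pred_vals = {0.9f, 0.2f, 0.7f};   // e.g. outputs of a sigmoid layer
std::vector<float> gold_vals = {1.f, 0.f, 1.f};      // true binary answers
Expression pred = input(cg, {3}, pred_vals);
Expression gold = input(cg, {3}, gold_vals);
Expression loss = binary_log_loss(pred, gold);
float l = as_scalar(cg.forward(loss));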
-
Expression dynet
::
pairwise_rank_loss
(const Expression &x, const Expression &y, real m = 1.0)¶ Pairwise rank loss.
A margin-based loss, where every margin violation for each pair of values is penalized: \(\sum_i max(m - x_i + y_i, 0)\)
- Return
- The pairwise rank loss
- Parameters
x
: A vector of valuesy
: A vector of true answersm
: The margin
-
Expression dynet
::
poisson_loss
(const Expression &x, unsigned y)¶ Poisson loss.
The negative log probability of
y
according to a Poisson distribution with parameterx
. Useful in Poisson regression, where we try to predict the parameters of a Poisson distribution to maximize the probability of the data y
.- Return
- The Poisson loss
- Parameters
x
: The parameter of the Poisson distribution.y
: The target value
-
Expression dynet
::
poisson_loss
(const Expression &x, const unsigned *py)¶ Modifiable Poisson loss.
Similar to Poisson loss, but with the target value passed by pointer so that it can be modified without re-constructing the computation graph.
- Return
- The Poisson loss
- Parameters
x
: The parameter of the Poisson distribution.py
: A pointer to the target value
Flow/Shaping Operations¶
These operations control the flow of information through the graph, or the shape of the vectors/tensors used in the graph.
-
Expression dynet
::
nobackprop
(const Expression &x)¶ Prevent backprop.
This node has no effect on the forward pass, but prevents gradients from flowing backward during the backward pass. This is useful when there’s a subgraph for which you don’t want loss passed back to the parameters.
- Return
- The new expression
- Parameters
x
: The input expression
-
Expression dynet
::
flip_gradient
(const Expression &x)¶ Negative backprop.
This node has no effect on the forward pass, but negates the gradient on the backward pass. This operation is widely used in adversarial networks (gradient reversal).
- Return
- An output expression identical to the input (only affects the backward pass)
- Parameters
x
: The input expression
-
Expression dynet
::
reshape
(const Expression &x, const Dim &d)¶ Reshape to another size.
This node reshapes a tensor to another size, without changing the underlying layout of the data. The layout of the data in DyNet is column-major, so if we have a 3x4 matrix
\( \begin{pmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} \\ \end{pmatrix} \)
and transform it into a 2x6 matrix, it will be rearranged as:
\( \begin{pmatrix} x_{1,1} & x_{3,1} & x_{2,2} & x_{1,3} & x_{3,3} & x_{2,4} \\ x_{2,1} & x_{1,2} & x_{3,2} & x_{2,3} & x_{1,4} & x_{3,4} \\ \end{pmatrix} \)
Note: This is O(1) for forward, and O(n) for backward.
- Return
- The reshaped expression
- Parameters
x
: The input expressiond
: The new dimensions
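To make the column-major behaviour concrete, a small sketch (the values are arbitrary):
ComputationGraph cg;
std::vector<float> vals(12);
for (unsigned i = 0; i < 12; ++i) vals[i] = (float)i;  // fills the 3x4 matrix column by column
Expression m = input(cg, {3, 4}, vals);
Expression r = reshape(m, Dim({2, 6}));                // same underlying data, new shape
std::vector<float> out = as_vector(cg.forward(r));     // the same 12 values, still column-major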
-
Expression dynet::transpose(const Expression & x, const std::vector< unsigned > & dims = {1, 0})
Transpose a matrix.
Transpose a matrix or tensor, or if dims is specified shuffle the dimensions arbitrarily. Note: This is O(1) if either the row or column dimension is 1, and O(n) otherwise.
- Return
- The transposed/shuffled expression
- Parameters
x
: The input expressiondims
: The dimensions to swap. The ith dimension of the output will be equal to the dims[i] dimension of the input. dims must have the same number of dimensions as x.
-
Expression dynet
::
select_rows
(const Expression &x, const std::vector<unsigned> &rows)¶ Select rows.
Select a subset of rows of a matrix.
- Return
- An expression containing the selected rows
- Parameters
x
: The input expressionrows
: The rows to extract
-
Expression dynet
::
select_rows
(const Expression &x, const std::vector<unsigned> *prows)¶ Modifiable select rows.
Select a subset of rows of a matrix, where the elements of prows can be modified without re-creating the computation graph.
- Return
- An expression containing the selected rows
- Parameters
x
: The input expressionprows
: The rows to extract
-
Expression dynet
::
select_cols
(const Expression &x, const std::vector<unsigned> &cols)¶ Select columns.
Select a subset of columns of a matrix. select_cols is more efficient than select_rows since DyNet uses column-major order.
- Return
- An expression containing the selected columns
- Parameters
x
: The input expressioncolumns
: The columns to extract
-
Expression dynet
::
select_cols
(const Expression &x, const std::vector<unsigned> *pcols)¶ Modifiable select columns.
Select a subset of columns of a matrix, where the elements of pcols can be modified without re-creating the computation graph.
- Return
- An expression containing the selected columns
- Parameters
x
: The input expressionpcolumns
: The columns to extract
-
Expression dynet
::
pick
(const Expression &x, unsigned v, unsigned d = 0)¶ Pick element.
Pick a single element/row/column/sub-tensor from an expression. This will result in the dimension of the tensor being reduced by 1.
- Return
- The value of x[v] along dimension d
- Parameters
x
: The input expressionv
: The index of the element to selectd
: The dimension along which to choose the element
-
Expression dynet
::
pick
(const Expression &x, const std::vector<unsigned> &v, unsigned d = 0)¶ Batched pick.
Pick elements from multiple batches.
- Return
- A mini-batched expression containing the picked elements
- Parameters
x
: The input expressionv
: A vector of indices to choose, one for each batch in the input expression.d
: The dimension along which to choose the elements
-
Expression dynet
::
pick
(const Expression &x, const unsigned *pv, unsigned d = 0)¶ Modifiable pick element.
Pick a single element from an expression, where the index is passed by pointer so we do not need to re-create the computation graph every time.
- Return
- The value of x[*pv]
- Parameters
x
: The input expressionpv
: Pointer to the index of the element to selectd
: The dimension along which to choose the elements
-
Expression dynet
::
pick
(const Expression &x, const std::vector<unsigned> *pv, unsigned d = 0)¶ Modifiable batched pick element.
Pick multiple elements from an input expression, where the indices are passed by pointer so we do not need to re-create the computation graph every time.
- Return
- A mini-batched expression containing the picked elements
- Parameters
x
: The input expressionpv
: A pointer to a vector of indices to choosed
: The dimension along which to choose the elements
-
Expression dynet
::
pick_range
(const Expression &x, unsigned s, unsigned e, unsigned d = 0)¶ Pick range of elements.
Pick a range of elements from an expression.
- Return
- The value of {x[s],…,x[e-1]} along dimension d
- Parameters
x
: The input expressions
: The start indexe
: The end indexd
: The dimension along which to pick
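A short sketch contrasting pick and pick_range on a length-4 vector (the values are arbitrary; the end index of pick_range is exclusive):
ComputationGraph cg;
std::vector<float> vals = {10.f, 20.f, 30.f, 40.f};
Expression v = input(cg, {4}, vals);
Expression second = pick(v, 1);           // scalar: 20
Expression middle = pick_range(v, 1, 3);  // vector: {20, 30}
std::cout << as_scalar(cg.forward(second)) << std::endl;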
-
Expression dynet
::
pick_batch_elem
(const Expression &x, unsigned v)¶ (Modifiable) Pick batch element.
Pick batch element from a batched expression. For a Tensor with 3 batch elements:
\( \begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)
pick_batch_elem(t, 1) will return a Tensor of
\( \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \)
- Return
- The expression of the picked batch element. The picked element is a tensor whose
bd
equals one. - Parameters
x
: The input expressionv
: The index of the batch element to be picked.
-
Expression dynet
::
pick_batch_elems
(const Expression &x, const std::vector<unsigned> &v)¶ (Modifiable) Pick batch elements.
Pick several batch elements from a batched expression. For a Tensor with 3 batch elements:
\( \begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)
pick_batch_elems(t, {1, 2}) will return a Tensor of with 2 batch elements:
\( \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)
- Return
- The expression of the picked batch elements. The result is a tensor whose
bd
equals the size of the vectorv
. - Parameters
x
: The input expressionv
: A vector of indices of the batch elements to be picked.
-
Expression dynet
::
pick_batch_elem
(const Expression &x, const unsigned *v)¶ Pick batch element.
Pick batch element from a batched expression.
- Return
- The expression of the picked batch element. The picked element is a tensor whose
bd
equals one. - Parameters
x
: The input expressionv
: A pointer to the index of the correct element to be picked.
-
Expression dynet
::
pick_batch_elems
(const Expression &x, const std::vector<unsigned> *pv)¶ Pick batch elements.
Pick several batch elements from a batched expression.
- Return
- The expression of the picked batch elements. The result is a tensor whose
bd
equals the size of the vectorv
. - Parameters
x
: The input expressionv
: A pointer to the indexes
-
Expression dynet
::
concatenate_to_batch
(const std::initializer_list<Expression> &xs)¶ Concatenate list of expressions to a single batched expression.
Perform a concatenation of several expressions along the batch dimension. All expressions must have the same shape except for the batch dimension.
- Return
- The expression with the batch dimensions concatenated
- Parameters
xs
: The input expressions
-
Expression dynet
::
strided_select
(const Expression &x, const std::vector<int> &strides, const std::vector<int> &from = {}, const std::vector<int> &to = {})¶ Strided select in multiple dimensions.
Select a range and/or stride of elements from an expression.
- Return
- The value of x[from[0]:to[0]:strides[0],..] (as it would be in numpy syntax)
- Parameters
x
: The input expressionstrides
: List of strides for each dimension, must be >= 1. Dimensions not included default to 1. Batch dimension can be included as very last dimension.from
: List of 0-based offsets (inclusive) for each dimension, must be >= 0. Dimensions not included default to 0. Batch dimension can be included as very last dimension.to
: List of highest 0-based index to select (exclusive) for each dimension, must be >= 0. Dimensions not included default to the corresponding dim size. Batch dimension can be included as very last dimension.
-
Expression dynet
::
concatenate_cols
(const std::initializer_list<Expression> &xs)¶ Concatenate columns.
Perform a concatenation of the columns in multiple expressions. All expressions must have the same number of rows.
- Return
- The expression with the columns concatenated
- Parameters
xs
: The input expressions
-
Expression dynet
::
concatenate
(const std::initializer_list<Expression> &xs, unsigned d = 0)¶ Concatenate.
Perform a concatenation of multiple expressions along a particular dimension. All expressions must have the same dimensions except for the dimension to be concatenated (rows by default).
- Return
- The expression with the specified dimension concatenated
- Parameters
xs
: The input expressionsd
: The dimension along which to perform concatenation
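A sketch concatenating two length-2 vectors, both along the row dimension and as columns:
ComputationGraph cg;
std::vector<float> a_vals = {1.f, 2.f}, b_vals = {3.f, 4.f};
Expression a = input(cg, {2}, a_vals);
Expression b = input(cg, {2}, b_vals);
Expression rows = concatenate({a, b});       // 4x1 vector
Expression cols = concatenate_cols({a, b});  // 2x2 matrix, one input per column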
-
Expression dynet
::
max_dim
(const Expression &x, unsigned d = 0)¶ Max out through a dimension.
Select out an element/row/column/sub-tensor from an expression, with the maximum value along a given dimension. This will result in the dimension of the tensor being reduced by 1.
- Return
- An expression of sub-tensor with max value along dimension d
- Parameters
x
: The input expressiond
: The dimension along which to choose the element
-
Expression dynet
::
min_dim
(const Expression &x, unsigned d = 0)¶ Min out through a dimension.
Select out an element/row/column/sub-tensor from an expression, with the minimum value along a given dimension. This will result in the dimension of the tensor being reduced by 1.
- Return
- An expression of sub-tensor with min value along dimension d
- Parameters
x
: The input expressiond
: The dimension along which to choose the element
Noise Operations¶
These operations are used to add noise to the graph for purposes of making learning more robust.
-
Expression dynet
::
noise
(const Expression &x, real stddev)¶ Gaussian noise.
Add gaussian noise to an expression.
- Return
- The noised expression
- Parameters
x
: The input expressionstddev
: The standard deviation of the gaussian
-
Expression dynet
::
dropout
(const Expression &x, real p)¶ Dropout.
With a fixed probability p, drop out (set to zero) nodes in the input expression, and scale the remaining nodes by 1/(1-p). Note that there are two kinds of dropout:
- Regular dropout: where we perform dropout at training time and then scale outputs by (1-p) at test time.
- Inverted dropout: where we perform dropout and scaling at training time, and do not need to do anything at test time. DyNet implements the latter, so you only need to apply dropout at training time, and do not need to perform scaling at test time.
- Return
- The dropped out expression
- Parameters
x
: The input expressionp
: The dropout probability
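A hedged usage sketch: apply dropout to a hidden vector only while training (the input and the training flag are illustrative placeholders for whatever your network computes):
ComputationGraph cg;
std::vector<float> h_vals(100, 0.5f);
Expression h = input(cg, {100}, h_vals);   // stands in for some hidden layer
bool training = true;                      // illustrative flag
if (training)
  h = dropout(h, 0.5);                     // inverted dropout: nothing extra needed at test time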
-
Expression dynet
::
dropout_dim
(const Expression &x, unsigned d, real p)¶ Dropout along a specific dimension.
Identical to the dropout operation except the dropout mask is the same across one dimension. Use this if you want to drop entire columns or rows of a matrix, for example.
For now this only supports tensors of order <= 3 (with or without batch dimension)
- Return
- The dropped out expression
- Parameters
x
: The input expressiond
: The dimension along which to dropp
: The dropout probability
-
Expression dynet
::
dropout_batch
(const Expression &x, real p)¶ Dropout entire elements of a minibatch.
Identical to the dropout operation except entire batch elements are dropped
- Return
- The dropped out expression
- Parameters
x
: The input expressionp
: The dropout probability
-
Expression dynet
::
block_dropout
(const Expression &x, real p)¶ Block dropout.
Identical to the dropout operation, but either drops out all or no values in the expression, as opposed to making a decision about each value individually.
- Return
- The block dropout expression
- Parameters
x
: The input expressionp
: The block dropout probability
Tensor Operations¶
These operations are used for performing operations on higher order tensors.
Remark: Compiling the contraction operations takes a lot of time with CUDA. For this reason, only the CPU implementation is compiled by default. If you need those operations, you need to un-comment this line in the source before compiling. TODO: make this simpler.
-
Expression dynet
::
contract3d_1d
(const Expression &x, const Expression &y)¶ Contracts a rank 3 tensor and a rank 1 tensor into a rank 2 tensor.
The resulting tensor \(z\) has coordinates \(z_{ij} = \sum_k x_{ijk} y_k\)
- Return
- Matrix
- Parameters
x
: Rank 3 tensory
: Vector
-
Expression dynet
::
contract3d_1d_1d
(const Expression &x, const Expression &y, const Expression &z)¶ Contracts a rank 3 tensor and two rank 1 tensors into a rank 1 tensor.
This is the equivalent of calling
contract3d_1d
and then performing a matrix vector multiplication.The resulting tensor \(t\) has coordinates \(t_i = \sum_{j,k} x_{ijk} y_k z_j\)
- Return
- Vector
- Parameters
x
: Rank 3 tensory
: Vectorz
: Vector
-
Expression dynet
::
contract3d_1d_1d
(const Expression &x, const Expression &y, const Expression &z, const Expression &b)¶ Same as
contract3d_1d_1d
with an additional bias parameter.This is the equivalent of calling
contract3d_1d
and then performing an affine transform.The resulting tensor \(t\) has coordinates \(t_i = b_i + \sum_{j,k} x_{ijk} y_k z_j\)
- Return
- Vector
- Parameters
x
: Rank 3 tensory
: Vectorz
: Vectorb
: Bias vector
-
Expression dynet
::
contract3d_1d
(const Expression &x, const Expression &y, const Expression &b)¶ Same as
contract3d_1d
with an additional bias parameter.The resulting tensor \(z\) has coordinates \(z_{ij} = b_{ij}+\sum_k x_{ijk} y_k\)
- Return
- Matrix
- Parameters
x
: Rank 3 tensory
: Vectorb
: Bias matrix
Linear Algebra Operations¶
These operations are used for performing various operations common in linear algebra.
-
Expression dynet
::
inverse
(const Expression &x)¶ Matrix Inverse.
Takes the inverse of a matrix (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158). Note that back-propagating through an inverted matrix can also be a source of numerical instability.
- Return
- The inverse of the matrix
- Parameters
x
: A square matrix
-
Expression dynet
::
logdet
(const Expression &x)¶ Log determinant.
Takes the log of the determinant of a matrix. (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158).
- Return
- The log of its determinant
- Parameters
x
: A square matrix
-
Expression dynet
::
trace_of_product
(const Expression &x, const Expression &y)¶ Trace of Matrix Product.
Takes the trace of the product of matrices. (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158).
- Return
- trace(x * y)
- Parameters
x
: A matrixy
: Another matrix
Convolution Operations¶
These operations are convolution-related.
-
Expression dynet
::
conv2d
(const Expression &x, const Expression &f, const std::vector<unsigned> &stride, bool is_valid = true)¶ conv2d without bias
2D convolution operator without bias parameters. ‘VALID’ and ‘SAME’ convolutions are supported. To see the distinction, consider the case where the stride is 1:
- SAME: the output size is the same as the input size. To achieve this, the input is padded so the filter can sweep outside of the input maps.
- VALID: the output size shrinks by filter_size - 1, and the filters only sweep over valid positions inside the input maps. No padding is needed.
In detail, assume:
- Input feature maps: (XH x XW x XC) x N
- Filters: FH x FW x XC x FC, 4D tensor
- Strides: strides[0] and strides[1] are row (h) and col (w) stride, respectively.
For the SAME convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH) / float(strides[0]))
- YW = ceil(float(XW) / float(strides[1])) and the paddings are computed as:
- pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
- pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
- pad_top = pad_along_height / 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width / 2
- pad_right = pad_along_width - pad_left
For the VALID convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH - FH + 1) / float(strides[0]))
- YW = ceil(float(XW - FW + 1) / float(strides[1])) and the paddings are always zeros.
- Return
- The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
- Parameters
x
: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimensionf
: 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensorstride
: the row and column stridesis_valid
: ‘VALID’ convolution or ‘SAME’ convolution, default is True (‘VALID’)
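A minimal sketch of a VALID convolution with made-up sizes (a 28x28 single-channel input and eight 3x3 filters):
ParameterCollection model;
Parameter p_f = model.add_parameters({3, 3, 1, 8});      // FH x FW x Ci x Co

ComputationGraph cg;
std::vector<float> img(28 * 28, 0.f);                    // dummy single-channel image
Expression x = input(cg, Dim({28, 28, 1}), img);
Expression f = parameter(cg, p_f);
std::vector<unsigned> stride = {1, 1};
Expression y = conv2d(x, f, stride, /*is_valid=*/true);  // output: 26 x 26 x 8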
-
Expression dynet
::
conv2d
(const Expression &x, const Expression &f, const Expression &b, const std::vector<unsigned> &stride, bool is_valid = true)¶ conv2d with bias
2D convolution operator with bias parameters. ‘VALID’ and ‘SAME’ convolutions are supported. To see the distinction, consider the case where the stride is 1:
- SAME: the output size is the same as the input size. To achieve this, the input is padded so the filter can sweep outside of the input maps.
- VALID: the output size shrinks by filter_size - 1, and the filters only sweep over valid positions inside the input maps. No padding is needed.
In detail, assume:
- Input feature maps: XH x XW x XC x N
- Filters: FH x FW x XC x FC
- Strides: strides[0] and strides[1] are row (h) and col (w) stride, respectively.
For the SAME convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH) / float(strides[0]))
- YW = ceil(float(XW) / float(strides[1])) and the paddings are computed as:
- pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
- pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
- pad_top = pad_along_height / 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width / 2
- pad_right = pad_along_width - pad_left
For the VALID convolution: the output height (YH) and width (YW) are computed as:
- YH = ceil(float(XH - FH + 1) / float(strides[0]))
- YW = ceil(float(XW - FW + 1) / float(strides[1])) and the paddings are always zeros.
- Return
- The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
- Parameters
x
: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimensionf
: 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensorb
: The bias (1D: Co)stride
: the row and column stridesis_valid
: ‘VALID’ convolution or ‘SAME’ convolution, default is True (‘VALID’)
-
Expression dynet
::
maxpooling2d
(const Expression &x, const std::vector<unsigned> &ksize, const std::vector<unsigned> &stride, bool is_valid = true)¶ maxpooling2d
2D maxpooling operator.
- Return
- The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
- Parameters
x
: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimensionksize
: the height and width of the maxpooling2d window or kernelstride
: the row and column stridesis_valid
: ‘VALID’ or ‘SAME’ (see comments for conv2d) , default is True (‘VALID’)
Normalization Operations¶
This includes batch normalization and the likes.
-
Expression dynet
::
layer_norm
(const Expression &x, const Expression &g, const Expression &b)¶ Layer normalization.
Performs layer normalization :
\( \begin{split} \mu &= \frac 1 n \sum_{i=1}^n x_i\\ \sigma &= \sqrt{\frac 1 n \sum_{i=1}^n (x_i-\mu)^2}\\ y&=\frac {\boldsymbol{g}} \sigma \circ (\boldsymbol{x}-\mu) + \boldsymbol{b}\\ \end{split} \)
Reference : Ba et al., 2016
- Return
- An expression of the same dimension as
x
- Parameters
x
: Input expression (possibly batched)g
: Gain (same dimension as x, no batch dimension)b
: Bias (same dimension as x, no batch dimension)
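A sketch with a 128-dimensional hidden vector and learned gain/bias parameters (the sizes are illustrative):
ParameterCollection model;
Parameter p_g = model.add_parameters({128});   // gain, same dimension as the input
Parameter p_b = model.add_parameters({128});   // bias, same dimension as the input

ComputationGraph cg;
std::vector<float> h_vals(128, 1.f);
Expression h = input(cg, {128}, h_vals);
Expression h_norm = layer_norm(h, parameter(cg, p_g), parameter(cg, p_b));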
-
Expression dynet
::
weight_norm
(const Expression &w, const Expression &g)¶ Weight normalization.
Performs weight normalization :
\( \begin{split} \hat{w} &= g\frac{w}{\Vert w\Vert}\\ \end{split} \)
Reference : Salimans, Kingma 2016
- Return
- An expression of the same dimension as
w
- Parameters
w
: Input expression (weight parameter)g
: Gain (scalar expression, usually also a parameter)
Device operations¶
These operations are device-related.
-
Expression dynet
::
to_device
(const Expression &x, Device *device)¶ Copy tensor between devices.
Copy tensor from x’s device to device
- Return
- An expression of x’s tensor in device
- Parameters
x
: Input expressiondevice
: Device to place return tensor
Builders¶
Builders combine various operations to implement more complicated structures such as recurrent networks, LSTMs, or hierarchical softmax.
RNN Builders¶
-
struct dynet
::
CoupledLSTMBuilder
¶ - #include <lstm.h>
CoupledLSTMBuilder creates an LSTM unit with a coupled input and forget gate as well as peephole connections.
More specifically, here are the equations for the dynamics of this cell :
\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+W_{ic}c_{t-1}+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+W_{oc}c_{t}+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
Inherits from dynet::RNNBuilder
Public Functions
-
dynet::CoupledLSTMBuilder
CoupledLSTMBuilder
()¶ Default constructor.
-
dynet::CoupledLSTMBuilder
CoupledLSTMBuilder
(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model)¶ Constructor for the LSTMBuilder.
- Parameters
layers
: Number of layersinput_dim
: Dimension of the input \(x_t\)hidden_dim
: Dimension of the hidden states \(h_t\) and \(c_t\)model
: ParameterCollection holding the parameters
-
unsigned dynet::CoupledLSTMBuilder
num_h0_components
() const¶ Number of components in
h_0
For
LSTMBuilder
, this corresponds to2 * layers
because it includes the initial cell state \(c_0\)- Return
2 * layers
-
std::vector<Expression> dynet::CoupledLSTMBuilder
get_s
(RNNPointer i) const¶ Get the final state of the hidden layer.
For
LSTMBuilder
, this consists of a vector of the memory cell values for each layer (l1, l2, l3), followed by the hidden state values- Return
- {c_{l1}, c_{l2}, …, h_{l1}, h_{l2}, …}
-
void dynet::CoupledLSTMBuilder
set_dropout
(float d)¶ Set the dropout rates to a unique value.
This has the same effect as
set_dropout(d,d_h,d_c)
except that all the dropout rates are set to the same value.- Parameters
d
: Dropout rate to be applied on all of \(x,h,c\)
-
void dynet::CoupledLSTMBuilder
set_dropout
(float d, float d_h, float d_c)¶ Set the dropout rates.
The dropout implemented here is an adaptation of the variational dropout with tied weights introduced in Gal, 2016 More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\), \(\mathbf{z_c}\sim \mathrm{Bernoulli}(1-d_c)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to :
\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{ih}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+W_{ic}(\frac 1 {1-d_c} {\mathbf{z_c}} \circ c_{t-1})+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{ch}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{oh}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+W_{oc}(\frac 1 {1-d_c} {\mathbf{z_c}} \circ c_{t})+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation
- Parameters
d
: Dropout rate \(d_x\) for the input \(x_t\)d_h
: Dropout rate \(d_h\) for the output \(h_t\)d_c
: Dropout rate \(d_c\) for the cell \(c_t\)
-
void dynet::CoupledLSTMBuilder
disable_dropout
()¶ Set all dropout rates to 0.
This is equivalent to
set_dropout(0)
orset_dropout(0,0,0)
-
void dynet::CoupledLSTMBuilder
set_dropout_masks
(unsigned batch_size = 1)¶ Set dropout masks at the beginning of a sequence for a specific batch size.
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element
- Parameters
batch_size
: Batch size
-
ParameterCollection &dynet::CoupledLSTMBuilder
get_parameter_collection
()¶ Get parameters in LSTMBuilder.
-
dynet::CoupledLSTMBuilder
-
struct dynet
::
VanillaLSTMBuilder
¶ - #include <lstm.h>
VanillaLSTM allows the creation of a “standard” LSTM, i.e. with decoupled input and forget gates and no peephole connections.
This cell runs according to the following dynamics :
\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
Inherits from dynet::RNNBuilder
Public Functions
-
dynet::VanillaLSTMBuilder
VanillaLSTMBuilder
()¶ Default Constructor.
-
dynet::VanillaLSTMBuilder
VanillaLSTMBuilder
(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model, bool ln_lstm = false, float forget_bias = 1.f)¶ Constructor for the VanillaLSTMBuilder.
- Parameters
layers
: Number of layersinput_dim
: Dimension of the input \(x_t\)hidden_dim
: Dimension of the hidden states \(h_t\) and \(c_t\)model
: ParameterCollection holding the parametersln_lstm
: Whether to use layer normalizationforget_bias
: Value (float) to use as the bias of the forget gate (default = 1.0)
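A sketch of the usual builder workflow, running a one-layer LSTM over a short sequence (dimensions and inputs are made up):
#include "dynet/lstm.h"

ParameterCollection model;
VanillaLSTMBuilder lstm(/*layers=*/1, /*input_dim=*/16, /*hidden_dim=*/32, model);

ComputationGraph cg;
lstm.new_graph(cg);                // once per computation graph
lstm.start_new_sequence();         // once per sequence
std::vector<float> step_vals(16, 0.1f);
Expression h;
for (unsigned t = 0; t < 5; ++t)
  h = lstm.add_input(input(cg, {16}, step_vals));
// h is the top-layer hidden state after the last step; feed it into a loss as usual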
-
void dynet::VanillaLSTMBuilder
set_dropout
(float d)¶ Set the dropout rates to a unique value.
This has the same effect as
set_dropout(d,d_h)
except that all the dropout rates are set to the same value.- Parameters
d
: Dropout rate to be applied on all of \(x,h\)
-
void dynet::VanillaLSTMBuilder
set_dropout
(float d, float d_r)¶ Set the dropout rates.
The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016 More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to :
\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation
- Parameters
d
: Dropout rate \(d_x\) for the input \(x_t\)d_h
: Dropout rate \(d_h\) for the output \(h_t\)
-
void dynet::VanillaLSTMBuilder
disable_dropout
()¶ Set all dropout rates to 0.
This is equivalent to
set_dropout(0)
orset_dropout(0,0)
-
void dynet::VanillaLSTMBuilder
set_dropout_masks
(unsigned batch_size = 1)¶ Set dropout masks at the beginning of a sequence for a specific batch size.
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element
- Parameters
batch_size
: Batch size
-
ParameterCollection &dynet::VanillaLSTMBuilder
get_parameter_collection
()¶ Get parameters in VanillaLSTMBuilder.
- Return
- The ParameterCollection holding the parameters of the builder
-
dynet::VanillaLSTMBuilder
-
struct dynet
::
CompactVanillaLSTMBuilder
¶ - #include <lstm.h>
VanillaLSTM allows the creation of a “standard” LSTM, i.e. with decoupled input and forget gates and no peephole connections.
This cell runs according to the following dynamics :
\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
Inherits from dynet::RNNBuilder
Public Functions
-
dynet::CompactVanillaLSTMBuilder
CompactVanillaLSTMBuilder
()¶ Default Constructor.
-
dynet::CompactVanillaLSTMBuilder
CompactVanillaLSTMBuilder
(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model)¶ Constructor for the CompactVanillaLSTMBuilder.
- Parameters
layers
: Number of layersinput_dim
: Dimension of the input \(x_t\)hidden_dim
: Dimension of the hidden states \(h_t\) and \(c_t\)model
: ParameterCollection holding the parameters
-
void dynet::CompactVanillaLSTMBuilder
set_dropout
(float d)¶ Set the dropout rates to a unique value.
This has the same effect as
set_dropout(d,d_h)
except that all the dropout rates are set to the same value.- Parameters
d
: Dropout rate to be applied on all of \(x,h\)
-
void dynet::CompactVanillaLSTMBuilder
set_dropout
(float d, float d_r)¶ Set the dropout rates.
The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016 More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to :
\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)
For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation
- Parameters
d
: Dropout rate \(d_x\) for the input \(x_t\)d_h
: Dropout rate \(d_h\) for the output \(h_t\)
-
void dynet::CompactVanillaLSTMBuilder
disable_dropout
()¶ Set all dropout rates to 0.
This is equivalent to
set_dropout(0)
orset_dropout(0,0)
-
void dynet::CompactVanillaLSTMBuilder
set_dropout_masks
(unsigned batch_size = 1)¶ Set dropout masks at the beginning of a sequence for a specific batch size.
If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element
- Parameters
batch_size
: Batch size
-
void dynet::CompactVanillaLSTMBuilder
set_weightnoise
(float std)¶ Set the Gaussian weight noise.
- Parameters
std
: Standard deviation of the weight noise
-
dynet::CompactVanillaLSTMBuilder
-
struct dynet
::
RNNBuilder
¶ - #include <rnn.h>
interface for constructing an RNN, LSTM, GRU, etc.
Subclassed by dynet::CompactVanillaLSTMBuilder, dynet::CoupledLSTMBuilder, dynet::DeepLSTMBuilder, dynet::FastLSTMBuilder, dynet::GRUBuilder, dynet::SimpleRNNBuilder, dynet::TreeLSTMBuilder, dynet::VanillaLSTMBuilder
Public Functions
-
dynet::RNNBuilder
RNNBuilder
()¶ Default constructor.
-
RNNPointer dynet::RNNBuilder
state
() const¶ Get pointer to the current state.
- Return
- Pointer to the current state
-
void dynet::RNNBuilder
new_graph
(ComputationGraph &cg, bool update = true)¶ Initialize with new computation graph.
call this to reset the builder when you are working with a newly created ComputationGraph object
- Parameters
cg
: Computation graphupdate
: Update internal parameters while training
-
void dynet::RNNBuilder
start_new_sequence
(const std::vector<Expression> &h_0 = {})¶ Reset for new sequence.
call this before add_input and after new_graph, when starting a new sequence on the same hypergraph.
- Parameters
h_0
:h_0
is used to initialize hidden layers at timestep 0 to given values
-
Expression dynet::RNNBuilder
set_h
(const RNNPointer &prev, const std::vector<Expression> &h_new = {})¶ Explicitly set the output state of a node.
- Return
- The hidden representation of the deepest layer
- Parameters
prev
: Pointer to the previous stateh_new
: The new hidden state
-
Expression dynet::RNNBuilder
set_s
(const RNNPointer &prev, const std::vector<Expression> &s_new = {})¶ Set the internal state of a node (for lstms/grus)
For RNNs without internal states (SimpleRNN, GRU…), this has the same behaviour as
set_h
- Return
- The hidden representation of the deepest layer
- Parameters
prev
: Pointer to the previous states_new
: The new state. Can be{new_c[0],...,new_c[n]}
or{new_c[0],...,new_c[n], new_h[0],...,new_h[n]}
-
Expression dynet::RNNBuilder
add_input
(const Expression &x)¶ Add another timestep by reading in the variable x.
- Return
- The hidden representation of the deepest layer
- Parameters
x
: Input variable
-
Expression dynet::RNNBuilder
add_input
(const RNNPointer &prev, const Expression &x)¶ Add another timestep, with arbitrary recurrent connection.
This allows you to define a recurrent connection to
prev
rather than tohead[cur]
. This can be used to construct trees, implement beam search, etc.- Return
- The hidden representation of the deepest layer
- Parameters
prev
: Pointer to the previous statex
: Input variable
-
void dynet::RNNBuilder
rewind_one_step
()¶ Rewind the last timestep.
- this DOES NOT remove the variables from the computation graph, it just means the next time step will see a different previous state. You can rewind as many times as you want.
-
RNNPointer dynet::RNNBuilder
get_head
(const RNNPointer &p)¶ Return the RNN state that is the parent of
p
- This can be used in implementing complex structures such as trees, etc.
-
void dynet::RNNBuilder
set_dropout
(float d)¶ Set Dropout.
- Parameters
d
: Dropout rate
-
void dynet::RNNBuilder
disable_dropout
()¶ Disable Dropout.
In general, you should disable dropout at test time
-
virtual Expression dynet::RNNBuilder
back
() const = 0¶ Returns node (index) of most recent output.
- Return
- Node (index) of most recent output
-
virtual std::vector<Expression> dynet::RNNBuilder
final_h
() const = 0¶ Access the final output of each hidden layer.
- Return
- Final output of each hidden layer
-
virtual std::vector<Expression> dynet::RNNBuilder
get_h
(RNNPointer i) const = 0¶ Access the output of any hidden layer.
- Return
- Output of each hidden layer at the given step
- Parameters
i
: Pointer to the step which output you want to access
-
virtual std::vector<Expression> dynet::RNNBuilder
final_s
() const = 0¶ Access the final state of each hidden layer.
This returns the state of each hidden layer, in a format that can be used in start_new_sequence (i.e. including any internal cell for LSTMs and the likes)
- Return
- vector containing, if it exists, the list of final internal states, followed by the list of final outputs for each layer
-
virtual std::vector<Expression> dynet::RNNBuilder
get_s
(RNNPointer i) const = 0¶ Access the state of any hidden layer.
See
final_s
for details- Return
- Internal state of each hidden layer at the given step
- Parameters
i
: Pointer to the step which state you want to access
-
virtual unsigned dynet::RNNBuilder
num_h0_components
() const = 0¶ Number of components in
h_0
- Return
- Number of components in
h_0
-
virtual void dynet::RNNBuilder
copy
(const RNNBuilder &params) = 0¶ Copy the parameters of another builder.
- Parameters
params
: RNNBuilder you want to copy parameters from.
-
dynet::RNNBuilder
-
struct dynet
::
SimpleRNNBuilder
¶ - #include <rnn.h>
This provides a builder for the simplest RNN with tanh nonlinearity.
The equation for this RNN is : \(h_t=\tanh(W_x x_t + W_h h_{t-1} + b)\)
Inherits from dynet::RNNBuilder
Public Functions
-
dynet::SimpleRNNBuilder
SimpleRNNBuilder
(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model, bool support_lags = false)¶ Builds a simple RNN.
- Parameters
layers
: Number of layersinput_dim
: Dimension of the inputhidden_dim
: Hidden layer (and output) sizemodel
: ParameterCollection holding the parameterssupport_lags
: Allow for auxiliary output?
-
Expression dynet::SimpleRNNBuilder
add_auxiliary_input
(const Expression &x, const Expression &aux)¶ Add auxiliary output.
Returns \(h_t=\tanh(W_x x_t + W_h h_{t-1} + W_y y + b)\) where \(y\) is an auxiliary output TODO : clarify
- Return
- The hidden representation of the deepest layer
- Parameters
x
: Input expressionaux
: Auxiliary output expression
-
dynet::SimpleRNNBuilder
-
struct dynet
::
TreeLSTMBuilder
¶ - #include <treelstm.h>
TreeLSTMBuilder is the base class for tree structured lstm builders.
Inherits from dynet::RNNBuilder
Subclassed by dynet::BidirectionalTreeLSTMBuilder, dynet::NaryTreeLSTMBuilder, dynet::UnidirectionalTreeLSTMBuilder
Public Functions
-
virtual Expression dynet::TreeLSTMBuilder
add_input
(int id, std::vector<int> children, const Expression &x) = 0¶ add input with given children at position id
if you did not call
set_num_elems
before, each successive id must be the previous id plus one and the children must all be smaller than id. If you usedset_num_elems
, id must be smaller than the number of elements and the children must have been already provided.- Parameters
id
: index wherex
should be storedchildren
: indices of the children for x
-
virtual void dynet::TreeLSTMBuilder
set_num_elements
(int num) = 0¶ Set the number of nodes in your tree in advance.
By default, input to a TreeLSTMBuilder needs to be in ascending order, i.e. when sequentializing the nodes, all leaves have to be first. If you know the number of elements beforehand, you can call this method to then place your nodes at arbitrary indices, e.g. because you already have a sequentialization that does not conform to the leaves-first requirement.
- Parameters
num
: desired size
-
virtual Expression dynet::TreeLSTMBuilder
-
struct dynet
::
NaryTreeLSTMBuilder
¶ - #include <treelstm.h>
Builds N-ary trees with a fixed upper bound of children. See “Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks” by Tai, Socher, and Manning (2015), section 3.2, for details on this model. http://arxiv.org/pdf/1503.00075v3.pdf.
Inherits from dynet::TreeLSTMBuilder
-
struct dynet
::
UnidirectionalTreeLSTMBuilder
¶ - #include <treelstm.h>
Builds a tree-LSTM which is recursively defined by a unidirectional LSTM over the node and its children representations.
Inherits from dynet::TreeLSTMBuilder
-
struct dynet
::
BidirectionalTreeLSTMBuilder
¶ - #include <treelstm.h>
Builds a tree-LSTM which is recursively defined by a Bidirectional LSTM over the node and its children representations.
Inherits from dynet::TreeLSTMBuilder
Softmax Builders¶
-
class dynet
::
SoftmaxBuilder
¶ - #include <cfsm-builder.h>
Interface for building softmax layers.
A softmax layer returns a probability distribution over \(C\) classes given a vector \(h\in\mathbb R^d\), with
\(p(c_i)\propto \exp(W_i^T h + b_i)\ \forall i\in\{1\ldots C\}\)
Where \(W\in \mathbb R^{C\times d}, b \in \mathbb R^C\)
Subclassed by dynet::ClassFactoredSoftmaxBuilder, dynet::HierarchicalSoftmaxBuilder, dynet::StandardSoftmaxBuilder
Public Functions
-
virtual void dynet::SoftmaxBuilder
new_graph
(ComputationGraph &cg, bool update = true) = 0¶ This initializes the parameters in the computation graph.
Call this once per ComputationGraph before any computation with the softmax
- Parameters
cg
: Computation graphupdate
: Whether to update the parameters
-
virtual Expression dynet::SoftmaxBuilder
neg_log_softmax
(const Expression &rep, unsigned classidx) = 0¶ Negative log probability of a class.
Given class \(c\) and vector \(h\), this returns \(-\log(p(c \mid h))\)
- Return
- \(-\log(p(\texttt{class} \mid \texttt{rep}))\)
- Parameters
rep
: Vector expressionclassidx
: Class index
-
virtual Expression dynet::SoftmaxBuilder
neg_log_softmax
(const Expression &rep, const std::vector<unsigned> &classidxs) = 0¶ Batched version of the former.
Returns a batched scalar
- Return
- \(-\log(p(\texttt{class}_b \mid \texttt{rep}_b))\) for each batch element \(b\)
- Parameters
rep
: Vector expression (batched)classidxs
: List of classes, one per batch element
-
virtual unsigned dynet::SoftmaxBuilder
sample
(const Expression &rep) = 0¶ Sample from the softmax distribution.
- Return
- Sampled class
- Parameters
rep
: Vector expression parametrizing the distribution
-
virtual Expression dynet::SoftmaxBuilder
full_log_distribution
(const Expression &rep) = 0¶ Returns an Expression representing a vector whose size is the number of classes.
The ith dimension gives \(\log p(c_i | \texttt{rep})\). This function may be SLOW. Avoid if possible.
- Return
- Expression of the distribution
- Parameters
rep
: Vector expression parametrizing the distribution
-
virtual Expression dynet::SoftmaxBuilder
full_logits
(const Expression &rep) = 0¶ Returns the logits (before application of the softmax)
The ith dimension gives \(W_i^Th + b_i\)
- Return
- Expression for the logits
- Parameters
rep
: Vector expression parametrizing the distribution
-
virtual ParameterCollection &dynet::SoftmaxBuilder
get_parameter_collection
() = 0¶ Returns the ParameterCollection containing the softmax parameters.
- Return
- ParameterCollection
-
virtual void dynet::SoftmaxBuilder
-
class dynet
::
StandardSoftmaxBuilder
¶ - #include <cfsm-builder.h>
This class implements the standard Softmax.
Inherits from dynet::SoftmaxBuilder
Public Functions
-
dynet::StandardSoftmaxBuilder
StandardSoftmaxBuilder
(unsigned rep_dim, unsigned num_classes, ParameterCollection &pc, bool bias = true)¶ Constructs a StandardSoftmaxBuilder.
This creates the parameters given the dimensions
- Parameters
rep_dim
: Dimension of the input vectorsnum_classes
: Number of classespc
: Parameter collectionbias
: Whether to use a bias vector or not
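A sketch of the typical calls (the dimensions, class index, and input vector are illustrative):
#include "dynet/cfsm-builder.h"

ParameterCollection model;
StandardSoftmaxBuilder smb(/*rep_dim=*/64, /*num_classes=*/1000, model);

ComputationGraph cg;
smb.new_graph(cg);                          // once per computation graph
std::vector<float> rep_vals(64, 0.f);
Expression rep = input(cg, {64}, rep_vals);
Expression nll = smb.neg_log_softmax(rep, /*classidx=*/42);  // training loss for class 42
unsigned sampled = smb.sample(rep);                          // draw a class from p(c|rep)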
-
dynet::StandardSoftmaxBuilder
StandardSoftmaxBuilder
(Parameter &p_w, Parameter &p_b)¶ Builds a softmax layer with pre-existing parameters.
- Parameters
p_w
: Weight matrixp_b
: Bias vector
-
dynet::StandardSoftmaxBuilder
StandardSoftmaxBuilder
(Parameter &p_w)¶ Builds a softmax layer with pre-existing parameters (no bias)
- Parameters
p_w
: Weight matrix
-
dynet::StandardSoftmaxBuilder
-
class dynet
::
ClassFactoredSoftmaxBuilder
¶ - #include <cfsm-builder.h>
Class factored softmax.
Each class is divided into subclasses, i.e. \(p(i\mid h)=p(i\mid h, c) p(c\mid h)\) where \(c\) is a class and \(i\) a subclass
Inherits from dynet::SoftmaxBuilder
Public Functions
-
dynet::ClassFactoredSoftmaxBuilder
ClassFactoredSoftmaxBuilder
(unsigned rep_dim, const std::string &cluster_file, Dict &word_dict, ParameterCollection &pc, bool bias = true)¶ Constructor from file.
This constructs the CFSM from a file with lines of the following format
CLASSID word [freq]
For words for instance
- Parameters
rep_dim
: Dimension of the input vectorcluster_file
: File containing classesword_dict
: Dictionary for words (maps words to index)pc
: ParameterCollectionbias
: Whether to use a bias vector or not
-
Expression dynet::ClassFactoredSoftmaxBuilder
class_log_distribution
(const Expression &rep)¶ Get log distribution over classes.
- Return
- Vector of \(\log(p(c\mid \texttt{rep}))\)
- Parameters
rep
: Input vector
-
Expression dynet::ClassFactoredSoftmaxBuilder
class_logits
(const Expression &rep)¶ Get logits of classes.
- Return
- Logits
- Parameters
rep
: Input vector
-
Expression dynet::ClassFactoredSoftmaxBuilder
subclass_log_distribution
(const Expression &rep, unsigned clusteridx)¶ Get log distribution over subclasses of class.
- Return
- Vector of \(\log(p(i\mid c, \texttt{rep}))\)
- Parameters
rep
: Input vectorclusteridx
: Class index
-
Expression dynet::ClassFactoredSoftmaxBuilder
subclass_logits
(const Expression &rep, unsigned clusteridx)¶ Logits over subclasses of class.
- Return
- Logits
- Parameters
rep
: Input vectorclusteridx
: Class index
-
dynet::ClassFactoredSoftmaxBuilder
Optimizers¶
The various optimizers that you can use to tune your parameters
-
struct dynet
::
SimpleSGDTrainer
¶ - #include <training.h>
Stochastic gradient descent trainer.
This trainer performs stochastic gradient descent, the go-to optimization procedure for neural networks. In the standard setting, the learning rate at epoch \(t\) is \(\eta_t=\frac{\eta_0}{1+\eta_{\mathrm{decay}}t}\)
Reference : reference needed
Inherits from dynet::Trainer
Public Functions
-
dynet::SimpleSGDTrainer
SimpleSGDTrainer
(ParameterCollection &m, real learning_rate = 0.1)¶ Constructor.
- Parameters
m
: ParameterCollection to be trainedlearning_rate
: Initial learning rate
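A minimal training-loop sketch tying a trainer to a loss (the model, data, and gold label are placeholders):
#include "dynet/training.h"

ParameterCollection model;
SimpleSGDTrainer trainer(model, /*learning_rate=*/0.1);
Parameter p_W = model.add_parameters({5, 10});

for (unsigned iter = 0; iter < 100; ++iter) {
  ComputationGraph cg;                        // a fresh graph per example is the common pattern
  Expression W = parameter(cg, p_W);
  std::vector<float> feats(10, 0.5f);         // dummy features
  Expression x = input(cg, {10}, feats);
  Expression loss = pickneglogsoftmax(W * x, /*gold=*/2);
  as_scalar(cg.forward(loss));
  cg.backward(loss);
  trainer.update();                           // SGD step with the current learning rate
}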
-
dynet::SimpleSGDTrainer
- struct dynet::CyclicalSGDTrainer
#include <training.h>
Cyclical learning rate SGD.
This trainer performs stochastic gradient descent with a cyclical learning rate as proposed in Smith, 2015.
This uses a triangular function with optional exponential decay.
More specifically, at each update, the learning rate \(\eta\) is updated according to:
\( \begin{split} \text{cycle} &= \left\lfloor 1 + \frac{\texttt{it}}{2 \times\texttt{step_size}} \right\rfloor\\ x &= \left\vert \frac{\texttt{it}}{\texttt{step_size}} - 2 \times \text{cycle} + 1\right\vert\\ \eta &= \eta_{\text{min}} + (\eta_{\text{max}} - \eta_{\text{min}}) \times \max(0, 1 - x) \times \gamma^{\texttt{it}}\\ \end{split} \)
Reference : Cyclical Learning Rates for Training Neural Networks
Inherits from dynet::Trainer
Public Functions
- dynet::CyclicalSGDTrainer::CyclicalSGDTrainer(ParameterCollection &m, float learning_rate_min = 0.01, float learning_rate_max = 0.1, float step_size = 2000, float gamma = 0.0, float edecay = 0.0)
  Constructor.
  - Parameters:
    - m: ParameterCollection to be trained
    - learning_rate_min: Lower learning rate
    - learning_rate_max: Upper learning rate
    - step_size: Period of the triangular function in number of iterations (not epochs). According to the original paper, this should be set to around (2-8) x (training iterations per epoch)
    - gamma: Learning rate upper bound decay parameter
    - edecay: Learning rate decay parameter. Ideally you shouldn't use this with cyclical learning rate, since decay is already handled by \(\gamma\)
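For intuition, the schedule above can be reproduced in a few lines of plain Python. This is a hypothetical helper that simply evaluates the formula (it is not part of the DyNet API), using gamma = 1.0, i.e. no decay, for clarity:
import math

def cyclical_lr(it, lr_min=0.01, lr_max=0.1, step_size=2000, gamma=1.0):
    # Triangular schedule with optional exponential decay (gamma < 1).
    cycle = math.floor(1 + it / (2 * step_size))
    x = abs(it / step_size - 2 * cycle + 1)
    return lr_min + (lr_max - lr_min) * max(0.0, 1 - x) * gamma ** it

# The learning rate rises from lr_min to lr_max over step_size iterations,
# falls back, and repeats every 2 * step_size iterations.
print(cyclical_lr(0), cyclical_lr(2000), cyclical_lr(4000))  # 0.01, 0.1, 0.01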
- struct dynet::MomentumSGDTrainer
#include <training.h>
Stochastic gradient descent with momentum.
This is a modified version of the SGD algorithm with momentum to stabilize the gradient trajectory. The modified gradient is \(\theta_{t+1}=\mu\theta_{t}+\nabla_{t+1}\) where \(\mu\) is the momentum.
Reference : reference needed
Inherits from dynet::Trainer
Public Functions
- dynet::MomentumSGDTrainer::MomentumSGDTrainer(ParameterCollection &m, real learning_rate = 0.01, real mom = 0.9)
  Constructor.
  - Parameters:
    - m: ParameterCollection to be trained
    - learning_rate: Initial learning rate
    - mom: Momentum
- struct dynet::AdagradTrainer
#include <training.h>
Adagrad optimizer.
The adagrad algorithm assigns a different learning rate to each parameter according to the following formula: \(\delta_\theta^{(t)}=-\frac{\eta_0}{\sqrt{\epsilon+\sum_{i=0}^{t-1}(\nabla_\theta^{(i)})^2}}\nabla_\theta^{(t)}\)
Reference : Duchi et al., 2011
Inherits from dynet::Trainer
Public Functions
- dynet::AdagradTrainer::AdagradTrainer(ParameterCollection &m, real learning_rate = 0.1, real eps = 1e-20)
  Constructor.
  - Parameters:
    - m: ParameterCollection to be trained
    - learning_rate: Initial learning rate
    - eps: Bias parameter \(\epsilon\) in the adagrad formula
- struct dynet::AdadeltaTrainer
#include <training.h>
AdaDelta optimizer.
The AdaDelta optimizer is a variant of Adagrad where \(\frac{\eta_0}{\sqrt{\epsilon+\sum_{i=0}^{t-1}(\nabla_\theta^{(i)})^2}}\) is replaced by \(\frac{\sqrt{\epsilon+\sum_{i=0}^{t-1}\rho^{t-i-1}(1-\rho)(\delta_\theta^{(i)})^2}}{\sqrt{\epsilon+\sum_{i=0}^{t-1}(\nabla_\theta^{(i)})^2}}\), hence eliminating the need for an initial learning rate.
Reference : ADADELTA: An Adaptive Learning Rate Method
Inherits from dynet::Trainer
Public Functions
- dynet::AdadeltaTrainer::AdadeltaTrainer(ParameterCollection &m, real eps = 1e-6, real rho = 0.95)
  Constructor.
  - Parameters:
    - m: ParameterCollection to be trained
    - eps: Bias parameter \(\epsilon\) in the adagrad formula
    - rho: Update parameter for the moving average of updates in the numerator
- struct dynet::RMSPropTrainer
#include <training.h>
RMSProp optimizer.
The RMSProp optimizer is a variant of Adagrad where the sum of squared previous gradients is replaced with a moving average with parameter \(\rho\).
Reference : reference needed
Inherits from dynet::Trainer
Public Functions
- dynet::RMSPropTrainer::RMSPropTrainer(ParameterCollection &m, real learning_rate = 0.1, real eps = 1e-20, real rho = 0.95)
  Constructor.
  - Parameters:
    - m: ParameterCollection to be trained
    - learning_rate: Initial learning rate
    - eps: Bias parameter \(\epsilon\) in the adagrad formula
    - rho: Update parameter for the moving average (rho = 0 is equivalent to using Adagrad)
- struct dynet::AdamTrainer
#include <training.h>
Adam optimizer.
The Adam optimizer is similar to RMSProp but uses unbiased estimates of the first and second moments of the gradient.
Reference : Adam: A Method for Stochastic Optimization
Inherits from dynet::Trainer
Public Functions
- dynet::AdamTrainer::AdamTrainer(ParameterCollection &m, float learning_rate = 0.001, float beta_1 = 0.9, float beta_2 = 0.999, float eps = 1e-8)
  Constructor.
  - Parameters:
    - m: ParameterCollection to be trained
    - learning_rate: Initial learning rate
    - beta_1: Moving average parameter for the mean
    - beta_2: Moving average parameter for the variance
    - eps: Bias parameter \(\epsilon\)
- struct dynet::EGTrainer
#include <training.h>
Exponentiated gradient optimizer with momentum and cyclical learning rate.
FIXME
Reference : FIXME
Inherits from dynet::Trainer
- struct dynet::Trainer
#include <training.h>
General trainer struct.
Subclassed by dynet::AdadeltaTrainer, dynet::AdagradTrainer, dynet::AdamTrainer, dynet::CyclicalSGDTrainer, dynet::EGTrainer, dynet::MomentumSGDTrainer, dynet::RMSPropTrainer, dynet::SimpleSGDTrainer
Public Functions
- dynet::Trainer::Trainer(ParameterCollection &m, real learning_rate)
  General constructor for a Trainer.
  - Parameters:
    - m: ParameterCollection to be trained
    - learning_rate: Initial learning rate
- void dynet::Trainer::update()
  Update parameters.
  Update the parameters according to the appropriate update rule.
- void dynet::Trainer::update(const std::vector<unsigned> &updated_params, const std::vector<unsigned> &updated_lookup_params)
  Update a subset of parameters.
  Update some but not all of the parameters included in the model. This is the update_subset() function in the Python bindings (see the sketch after this section). The parameters to be updated are specified by index, which can be found for Parameter and LookupParameter objects through the "index" variable (or the get_index() function in the Python bindings).
  - Parameters:
    - updated_params: The parameter indices to be updated
    - updated_lookup_params: The lookup parameter indices to be updated
- virtual void dynet::Trainer::restart() = 0
  Restarts the optimizer.
  Clears all momentum values and similar accumulated statistics (if applicable).
Public Members
- bool dynet::Trainer::sparse_updates_enabled
  Whether to perform sparse updates.
  DyNet trainers support two types of updates for lookup parameters, sparse and dense. Sparse updates are the default. They have the potential to be faster, as they only touch the parameters that have non-zero gradients. However, they may not always be faster (particularly on GPU with mini-batch training), and are not precisely numerically correct for some update rules such as MomentumTrainer and AdamTrainer. Thus, if you set this variable to false, the trainer will perform dense updates, which are guaranteed to be exactly correct and are sometimes faster.
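As noted above, the subset update is exposed in the Python bindings as update_subset(), with parameter indices obtained via get_index(). A hedged sketch, with an illustrative loss:
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((10, 10))
b = pc.add_parameters((10,))
E = pc.add_lookup_parameters((100, 10))
trainer = dy.SimpleSGDTrainer(pc)

dy.renew_cg()
# Toy loss just to populate gradients.
loss = dy.squared_norm(dy.parameter(W) * dy.inputVector([1.0] * 10) + dy.parameter(b))
loss.forward()
loss.backward()
# Update only W (by its index); b and the lookup parameters E are left untouched.
trainer.update_subset([W.get_index()], [])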
DyNet Examples¶
This is a set of common (and less common) models and their implementations in DyNet (C++ and Python).
Some examples exist in only one of the two languages or lack documentation, and we welcome contributions to fill these gaps. Documentation should include directions on how to download standard datasets, run these examples on those datasets, and calculate standard measures of accuracy, etc. A good example of a simple README is in the mnist directory.
Note that these examples are meant to be minimal examples, not necessarily the state of the art. Concurrently, we are working on creating a state-of-the-art model repository. In the meantime, you can browse the many research projects that use DyNet and find one that fits your needs.
Main Model Examples¶
These examples are of common models and are intended to be relatively well maintained.
- XOR: The simplest possible model, solving xor (C++/Python).
- MNIST: An example of MNIST image classification using a simple multi-layer perceptron (C++).
- RNN Language Model: A recurrent neural network language model (C++/Python).
- Sequence-to-sequence Model: Sequence to sequence models using standard encoder decoders, or attention (C++/Python).
- BiLSTM Tagger: Models that do sequence labeling with BiLSTM feature extractors (C++/Python).
- Text Categorization: Models for text categorization (C++/Python).
- Word Embedding: Models for word embedding (C++).
Functionality Examples¶
These examples demonstrate how to take advantage of various types of functionality of DyNet.
- Batching: How to use mini-batch training (C++/Python).
- Automatic Batching: How to use DyNet’s automatic batching functionality (C++).
- Devices: How to use DyNet on CPUs, GPUs, or multiple devices (C++/Python).
- Multiprocessing: DyNet’s multiprocessing functionality for training models in parallel (C++).
- TensorBoard: How to use DyNet with TensorBoard through PyCrayon (Python).
- Reading/Writing: How to read/write models (C++).
- Jupyter Tutorials: Various tutorials in the form of Jupyter notebooks (Python).
Auxiliary Model Examples¶
These are somewhat less common and not necessarily well supported, but still may be useful for some people.
- Document Classification: An example of modeling documents with a hierarchical model (C++).
- Feed Forward Language Model: A model for predicting the next word using a feed forward network (C++).
- Poisson Regression: A model for predicting an integer using Poisson regression given a sentence (C++).
- Sentence Embedding: A model for learning sentence embeddings from parallel data, with negative sampling (C++).
- Variational Auto-encoders: Examples using variational auto-encoders (C++).
- Noise Contrastive Estimation: Examples using noise contrastive estimation to speed training (C++).
- Softmax Builders: Examples of how to use other types of softmax functions, including class factored softmax (C++).
- Segmental RNNs: A segmental RNN model (C++).
More advanced topics are covered below:
Minibatching¶
Minibatching Overview¶
Minibatching takes multiple training examples and groups them together to be processed simultaneously, often allowing for large gains in computational efficiency due to the fact that modern hardware (particularly GPUs, but also CPUs) have very efficient vector processing instructions that can be exploited with appropriately structured inputs.
As shown in the figure below, common examples of this in neural networks include grouping together matrix-vector multiplies from multiple examples into a single matrix-matrix multiply, or performing an element-wise operation (such as tanh) over multiple vectors at the same time as opposed to processing single vectors individually.

In most neural network toolkits, mini-batching is largely left to the user, with a bit of help from the toolkit. This is usually done by adding an additional dimension to the tensor that they are interested in processing, and ensuring that all operations consider this dimension when performing processing. This adds some cognitive load, as the user must keep track of this extra batch dimension in all their calculations, and also ensure that they use the correct ordering of the batch dimensions to achieve maximum computational efficiency. Users must also be careful when performing operations that combine batched and unbatched elements (such as batched hidden states of a neural network and unbatched parameter matrices or vectors), in which case they must concatenate vectors into batches, or “broadcast” the unbatched element, duplicating it along the batch dimension to ensure that there are no illegal dimension mismatches.
DyNet hides most of this complexity from the user through easy-to-use mini-batching functionality that can either be completely automatic, or specified manually by the user, as described below.
Automatic Mini-batching¶
If you want to get many of the benefits of mini-batching without doing any work, you can use DyNet’s automatic mini-batching functionality.

This functionality is enabled with the --dynet-autobatch 1
command line option; when it is set, DyNet will automatically attempt to find operations that can be batched together to improve efficiency.
To take full advantage of this, you will want to create a big computation graph that represents multiple training examples by simply iterating over the multiple training examples as follows:
for minibatch in training_data:
    dy.renew_cg()
    losses = []
    for x, y in minibatch:
        l = calculate_my_loss(x, y)
        losses.append(l)
    loss = dy.esum(losses)
    loss.forward()
    loss.backward()
    trainer.update()
This is nice because the calculate_my_loss function can be arbitrarily complex and doesn’t have to have the same structure across sentences. A full example of mini-batching in action for a tree-structured neural network model can be found here for C++ and Python.
Manual Mini-batching¶
In easy-to-minibatch networks where the structure remains the same across multiple sentences, it is possible to get some further gains by performing manual mini-batching, similarly to what you do in other toolkits. Even in this case, DyNet hides much of this complexity from the user through the use of specially designed batching operations which treat the number of mini-batch elements not as another standard dimension, but as a special dimension with particular semantics. Broadcasting is done behind the scenes by each operation implemented in DyNet, and thus the user must only think about inputting multiple pieces of data for each batch, and calculating losses using multiple labels.
First, let’s take a look at a non-minibatched example using the Python API.
In this example, we look up word embeddings word_1 and word_2 using lookup parameters E. We then perform an affine transform using weights W and bias b, and perform a softmax. Finally, we calculate the loss given the true label out_label.
# in_words is a tuple (word_1, word_2)
# out_label is an output label
word_1 = E[in_words[0]]
word_2 = E[in_words[1]]
scores_sym = W*dy.concatenate([word_1, word_2])+b
loss_sym = dy.pickneglogsoftmax(scores_sym, out_label)
Next, let’s take a look at the mini-batched version:
# in_words is a list [(word_{1,1}, word_{1,2}), (word_{2,1}, word_{2,2}), ...]
# out_labels is a list of output labels [label_1, label_2, ...]
word_1_batch = dy.lookup_batch(E, [x[0] for x in in_words])
word_2_batch = dy.lookup_batch(E, [x[1] for x in in_words])
scores_sym = W*dy.concatenate([word_1_batch, word_2_batch])+b
loss_sym = dy.sum_batches( dy.pickneglogsoftmax_batch(scores_sym, out_labels) )
We can see there are only 4 major changes: the word IDs need to be transformed into lists of IDs instead of a single ID, we need to call lookup_batch instead of the standard lookup, we need to call pickneglogsoftmax_batch instead of the unbatched version, and we need to call sum_batches at the end to sum the loss from all the batches.
A full example of mini-batching in action for a recurrent neural language model can be found here for C++ and Python.
The Mini-batch Dimension¶
The way DyNet handles this is by using a special privileged "mini-batch element" dimension, which indicates the number of training examples in the mini-batch. To give an example, we can declare a Dim object in C++
Dim d({2,4,8}, 16)
or Python
d = Dim([2,4,8], 16)
Here, 2,4,8 are the dimensions of the data in the tensor for each example, while 16 is the number of examples in the mini-batch. When we print out the dimensions (for example, when calling the print_graphviz() functionality for debugging), this will be printed as {2,4,8x16}.
Mini-batched Functions¶
For the great majority of standard operations, things should work seamlessly for minibatched elements. The one condition is that all inputs must have either one mini-batch element only, or the same number of mini-batch elements. So a binary function f(x,y) could take inputs where the numbers of minibatch elements in x/y are 1/1, 4/1, 1/4, or 4/4 respectively. However, it is not possible to have different non-one numbers of minibatch elements, such as x/y having minibatch sizes of 2/4.
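As a quick illustration of this one-versus-N broadcasting rule, here is a hedged Python sketch; it assumes inputTensor's batched argument, which treats the last NumPy dimension as the mini-batch dimension:
import dynet as dy
import numpy as np

pc = dy.ParameterCollection()
b = pc.add_parameters((3,))

dy.renew_cg()
# x carries 4 mini-batch elements, b carries 1: the single-element input is
# broadcast over the batch (1/4 is fine, but e.g. 2/4 would be an error).
x = dy.inputTensor(np.random.rand(3, 4), batched=True)
y = x + dy.parameter(b)
print(y.dim())   # expected: ((3,), 4) -- 3-dimensional vectors, batch of 4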
There are some operations where we need to explicitly think about batching, mostly on the input and output sides of the graph. These include input operations:
- lookup() (C++) and lookup_batch() (Python): Performs lookup over a vector of input IDs, where each input ID is an element of the mini-batch.
- input(): C++ input can specify a Dim object that is mini-batched. In Python, directly adding batched input is not supported yet, but there is a workaround (https://github.com/clab/dynet/issues/175) using reshape().
Loss calculation operations:
- pickneglogsoftmax() (C++) and pickneglogsoftmax_batch() (Python): Calculates the negative log softmax loss over multiple batch elements.
- hinge() (C++): Similarly, calculates hinge loss over multiple elements.
Manipulation operations:
- reshape(): Can be used to reshape into tensors with a batch element of more than one.
- pick() (C++) and pick_batch() (Python): Picks an element for each of the mini-batch elements.
- sum_batches(): Will sum together all of the values in the batch. This is often used to sum together the loss function before performing the backward step. A short sketch combining reshape(), pick_batch(), and sum_batches() follows below.
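Here is a hedged Python sketch combining these manipulation operations; it assumes the reshape() binding accepts a batch_size argument (as in the workaround linked above), and the shapes are illustrative:
import dynet as dy
import numpy as np

dy.renew_cg()
# Start from an unbatched 8x4 matrix holding 4 examples of dimension 8,
# then move the second dimension into the mini-batch dimension.
m = dy.inputTensor(np.random.rand(8, 4))
batched = dy.reshape(m, (8,), batch_size=4)
# Pick one element per mini-batch element, then sum the 4 resulting scalars.
picked = dy.pick_batch(batched, [0, 1, 2, 3])
total = dy.sum_batches(picked)
print(total.value())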
Multi-processing¶
In addition to minibatch support, the DyNet C++ API also supports training models using many CPU cores (Python support is pending). This is particularly useful when performing training of networks that are not conducive to simple mini-batching, such as tree-structured networks.
DyNet abstracts most of the behind-the-scenes grit from the user.
The user defines a function to be called for each datum in the training data set, and passes this function, along with an array of data, to DyNet.
Internally, DyNet launches a pool of training processes and automatically handles passing data examples to each worker.
Each worker process individually processes a datum, computing the results of the forward and backward passes, computes gradients with respect to each parameter, and passes these results back to the parent process via a shared memory variable.
Whenever the parent process, which is also processing data, completes a gradient computation, it averages all of the gradients currently in the shared memory gradient storage and updates all parameters with respect to that average gradient.
In this way, running training on n cores is similar to training with a stochastic minibatch size with expected value of approximately n.
This method is quite efficient, achieving nearly linear speedups with increasing numbers of cores, due to its lockless nature.
Examples of how to use the multi-processing API can be found in the xor-mp and rnnlm-mp sections of the examples/cpp directory.
Unorthodox Design¶
There are a couple design decisions about DyNet that are different from the way things are implemented in other libraries, or different from the way you might expect things to be implemented. The items below are a list of these unorthodox design decisions, which you should read to avoid being surprised. We also try to give some justification for these decisions (although we realize that this is not the only right way to do things).
Sparse Updates¶
By default, DyNet parameter optimizers perform sparse updates over
LookupParameters
. This means that if you have a LookupParameters
object, use a certain subset of indices, then perform a parameter update, the
optimizer will loop over the used subset, and not perform any updates over
the unused values. This can improve efficiency in some cases: e.g. if you have
embeddings for a vocabulary of 100,000 words and you only use 5 of them in a
particular update, this will avoid doing updates over all 100,000. However,
there are two things to be careful of. First, this means that some update rules,
particularly those using momentum such as MomentumSGDTrainer
and AdamTrainer,
are not strictly correct (these could be made correct with some effort, but
this would complicate the programming interface, which we have opted against).
Also, on GPUs, because large operations are
relatively cheap, it can sometimes be faster to just perform a single operation
over all of the parameters, as opposed to multiple small operations. In this
case, you can set the sparse_updates_enabled
variable of your Trainer
to false
, and DyNet will perform a standard dense update, which is
guaranteed to be exactly correct, and potentially faster on GPU.
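From the Python bindings, the switch to dense updates can be made as in the sketch below; this assumes the trainer's set_sparse_updates() helper, which corresponds to the sparse_updates_enabled member described in the C++ API above:
import dynet as dy

pc = dy.ParameterCollection()
E = pc.add_lookup_parameters((100000, 128))   # large embedding table
trainer = dy.AdamTrainer(pc)
# Request dense updates so Adam's moment estimates are exactly correct
# (and possibly faster on GPU), at the cost of touching every row of E.
trainer.set_sparse_updates(False)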
Weight Decay¶
As described in the Command Line Options, weight decay is implemented
through the option --dynet-weight-decay
. If this value is set to wd
,
each parameter in the model is multiplied by (1-wd)
after every parameter
update. This weight decay is similar to L2 regularization, and is equivalent in
the case of using simple SGD (SimpleSGDTrainer
), but it is not the same
when using any other optimizers such as AdagradTrainer
or AdamTrainer
.
You can still try to use weight decay with these optimizers, and it might work,
but if you really want to correctly apply L2 regularization with these
optimizers, you will have to directly calculate the L2 norm of each of the
parameters and add it to the objective function before performing your update.
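For the explicit L2 route suggested above, here is a hedged Python sketch; the coefficient 1e-4 and the toy loss are purely illustrative:
import dynet as dy

pc = dy.ParameterCollection()
W = pc.add_parameters((2, 2))
trainer = dy.AdamTrainer(pc)

dy.renew_cg()
x = dy.inputVector([1.0, -1.0])
data_loss = dy.pickneglogsoftmax(dy.parameter(W) * x, 0)
# Add the squared L2 norm of each regularized parameter to the objective,
# instead of relying on --dynet-weight-decay with Adam.
loss = data_loss + 1e-4 * dy.squared_norm(dy.parameter(W))
loss.forward()
loss.backward()
trainer.update()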
Minibatching Implementation¶
Minibatching in DyNet is different than how it is implemented in other libraries. In other libraries, you can create minibatches by explicitly adding another dimension to each of the variables that you want to process, and managing them yourself. Instead, DyNet provides special Operations that allow you to perform input, lookup, or loss calculation over mini-batched input, then DyNet will handle the rest. The programming paradigm is a bit different from other toolkits, and may take a bit of getting used to, but is often more convenient once you’re used to it.
LSTM Implementation¶
The implementation of LSTMs in LSTMBuilder
is not the canonical
implementation, but an implementation using coupled input and forget gates, as
described in “LSTM: A Search Space Odyssey” (https://arxiv.org/abs/1503.04069).
In other words, if the value of the input gate is i, the forget gate is 1-i.
This reduces the number of parameters in the model and speeds training a little,
and in many cases the accuracy is the same or better. If you want to try the
standard version of the LSTM, use the VanillaLSTMBuilder
class.
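For reference, both builders are constructed the same way in the Python bindings; a minimal sketch with illustrative layer sizes:
import dynet as dy

pc = dy.ParameterCollection()
# Coupled input/forget gates (the LSTMBuilder behavior described above):
coupled = dy.LSTMBuilder(1, 32, 64, pc)      # layers, input dim, hidden dim, pc
# Standard LSTM with separate input and forget gates:
vanilla = dy.VanillaLSTMBuilder(1, 32, 64, pc)

dy.renew_cg()
state = vanilla.initial_state()
output = state.add_input(dy.inputVector([0.0] * 32)).output()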
Dropout Scaling¶
When using dropout to help prevent overfitting, dropout is generally applied at training time, then at test time all the nodes in the neural net are used to make the final decision, increasing robustness. However, because there is a disconnect between the number of nodes being used in each situation, it is important to scale the values of the output to ensure that they match in both situations. There are two ways to do this:
- Vanilla Dropout: At training time, randomly keep each node with probability p (dropping it otherwise). At test time, scale the outputs of each node by p.
- Inverted Dropout: At training time, keep each node with probability p and scale the kept outputs by 1/p. At test time, use the outputs as-is.
The first is perhaps more common, but the second is convenient, because we only need to think about dropout at training time, and thus DyNet opts to use the latter. See here for more details on these two methods.
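In the Python bindings, dy.dropout applies the inverted scheme for you: it takes the drop probability and rescales the surviving units during training, so nothing needs to change at test time. A minimal sketch (the rate and inputs are illustrative):
import dynet as dy

pc = dy.ParameterCollection()
DROP_RATE = 0.5   # probability of zeroing a unit (illustrative)

def encode(x_values, train):
    h = dy.inputVector(x_values)
    if train:
        # Inverted dropout: surviving units are rescaled during training,
        # so the same expression is used unchanged at test time.
        h = dy.dropout(h, DROP_RATE)
    return h

dy.renew_cg()
train_h = encode([1.0, 2.0, 3.0, 4.0], train=True)
test_h = encode([1.0, 2.0, 3.0, 4.0], train=False)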
Projects using DyNet¶
DyNet works for your complex neural networks¶
DyNet was designed from the ground up to be fast for neural networks with complex structure or control flow such as the ones that you need to handle tree or graph structures, or perform reinforcement learning or training with exploration. Below are some examples of full systems that use DyNet to handle their dynamic neural network needs.
Syntactic Parsing¶
Parsing is currently the most prominent scenario in which DyNet has been used, and DyNet was behind the development of a number of methods such as stack LSTMs, bi-directional LSTM feature extractors for dependency parsing, recurrent neural network grammars, and hierarchical tree LSTMs. A submission to the CoNLL shared task on dependency parsing using DyNet registered second place, and was nearly an order of magnitude faster than other submissions.
Machine Translation¶
DyNet is the backend chosen by a number of machine translation systems such as Mantis, Lamtram, nmtkit, and xnmt. It has powered the development of models that use complicated structures, such as lattice-to-sequence models.
Speech Recognition¶
DyNet powers the “Listen, Attend, and Spell” style models in xnmt. It has also been used to implement acoustic models using connectionist temporal classification (CTC).
Graph Parsing¶
DyNet powers the transition based UCCA parser that can predict graph structures from text.
Language Modeling¶
DyNet has been used in the development of hybrid neural/n-gram language models, and generative syntactic language models.
Tagging¶
DyNet supports applications to tagging for named entity recognition, semantic role labeling, punctuation prediction, and has been used in the creation of new architectures such as segmental recurrent neural networks.
Morphology¶
DyNet has been used in seminal work for morphological inflection generation and inflection generation with hard attention.
And we welcome your contributions!
Contributing to Dynet¶
DyNet is an open source project that is only possible because of contributions from users like you! We greatly welcome any contributions, whether they correct a bug or add a feature, and these should be made through a pull request on the DyNet github page.
Below are some guidelines to guarantee consistency.
Coding Tips and Style¶
Coding Practices¶
Testing:
Before committing any code, tests should be run to make sure that the new code didn’t break anything.
This can be done by using the make test
command.
It is also highly recommended that you add unit tests for any new functionality.
Unit tests are implemented in the tests
directory.
When making a bug fix, you can add a test that broke before the fix but passes afterwards.
That being said, tests are not an absolute requirement, so if you have a contribution but aren’t sure how to do tests, please don’t let this stop you from contributing.
Coding Style Conventions¶
DyNet (the main version in C++) has certain coding style standards:
Overall Philosophy: DyNet is designed to minimize the computational overhead when creating networks. Try to avoid doing slow things like creating objects or copying memory in places that will be called frequently during computation graph construction.
Function Names: Function names are written in “snake_case”.
const: Always use const if the input to a function is constant.
Pointer vs. Reference: When writing functions, use the following guidelines (quoted from here):
- Only pass a value by pointer if the value 0/NULL is a valid input in the current context.
- If a function argument is an out-value, then pass it by reference.
- Choose “pass by value” over “pass by const reference” only if the value is a POD (Plain Old Datastructure) or small enough (memory-wise) or in other ways cheap enough (time-wise) to copy.
Error handling: The C++ core of DyNet provides a mechanism for error handling that
should be used in all code. It consists of 3 macros as follows (included in globals.h):
- DYNET_INVALID_ARG(msg): This is used to throw an error that is triggered when a user passes an invalid argument to one of the functions.
- DYNET_RUNTIME_ERR(msg): This is used to throw an error that could be triggered by a user, but is not the result of an invalid argument. For example, it could be used when something is not implemented yet, or when the program dies due to lack of memory, etc.
- DYNET_ASSERT(expr,msg): This is to be used to check things that should only happen due to a programming error within DyNet itself, and should never be triggered by a user. expr is a condition, and msg is a message explaining the exception, with ostream-style formatting.
Coding Tips/How To¶
Adding New Operations¶
One of the most common things that one will want to do to modify DyNet is to add a new operation to calculate a new function. You can find more information on how to do so at the end of the tutorial slides here (note that some file names are old).
Taking a look at the existing operations in the nodes-XXX.h
and nodes-XXX.cc
files
will be the best guide in creating new operations. Here are some fine-grained tips for
those that want to dive into the process.
- fx is a pointer to the (preallocated) location for the result of forward to be stored
- fx is not initialized, so after calling forward fx must contain the correct answer
- dEdxi MUST ACCUMULATE a result, since multiple calls to forward may depend on the same x_i. Even, e.g., Identity must be implemented as dEdx1 += dEdf.
- scalar results of forward are placed in fx.v[0]
- DyNet manages its own memory, not Eigen, and it is configured with the EIGEN_NO_MALLOC option. If you get an error about Eigen attempting to allocate memory, it is (probably) because of an implicit creation of a temporary variable. If you really do need a temporary variable, its capacity must be requested by Node::aux_storage_size
And here are some notes on debugging problems with new operations:
- fx is uninitialized when forward is called: are you relying on it being 0?
- dEdxi must accumulate (see the third point above!)
Decreasing Compile Time¶
DyNet has a GPU_NUMFILES
option that allows you to specify the number of separate
files that are compiled when compiling for GPU. This is set to 4 on Linux/Mac and 1 on
Windows (because MSVC doesn’t support parallel compilation for GPU code). If you’re
developing new operations for DyNet, it might be a good idea to set GPU_NUMFILES
to zero, which will result in all nodes-XXX.cc
files being compiled separately.
In this case, if you change a single file, it will only recompile that file instead
of recompiling all of the code in all of the nodes-XXX.cc
files.
Documentation¶
DyNet uses Doxygen for commenting the code and Sphinx for the general documentation.
If you're only documenting features, you don't need to concern yourself with Sphinx; your doxygen comments will be integrated into the documentation automatically.
Doxygen guidelines¶
Please document any publicly accessible function you write using the doxygen syntax.
You can see examples in the training file. The most important thing is to use /*
style comments and \command
style commands.
For ease of access the documentation is divided into groups. For now the groups are optimizers and operations. If you implement a function that falls into one of these groups, add \ingroup [group name]
at the beginning of your comment block.
If you want to create a group, use \defgroup [group-name]
at the beginning of your file. Then create a file for this group in sphinx (see next section).
Important: You can use LaTeX in doxygen comments with the syntax \f$ \f$. For some reason, since readthedocs updated their version of sphinx, \f[ \f] doesn't work anymore, so don't use it: it breaks the build.
Sphinx guidelines¶
The sphinx source files are located in doc/source
. They describe the documentation’s organization using the reStructuredText Markup language.
Although reStructuredText is more powerful than Markdown, it might feel less intuitive, especially when writing long documents. If need be, you can write your doc in Markdown and convert it using Pandoc.
For a tutorial on Sphinx see their tutorial.
Doxygen generated XML is integrated in sphinx files using the Breathe module. The only breathe command used now is doxygengroup
. You shouldn't use commands for individual classes/functions/structs without a good reason. Most information should be put in the doxygen comments.
Building the docs¶
The documentation is automatically rebuilt by ReadTheDocs each time you push to GitHub.
If you want to build the documentation locally you’ll need to install doxygen, sphinx and breathe and then run build_doc.sh
from the doc
folder.
Contributors¶
DyNet was started at Carnegie Mellon University by Chris Dyer (now at DeepMind), and the project is now led by Graham Neubig (CMU) and Yoav Goldberg (Bar Ilan University). It relies on contributors from a wide variety of institutions, including:






and many others!
Citing/Logos¶
If you use DyNet for research, please cite our technical report as follows:
@article{dynet,
title={DyNet: The Dynamic Neural Network Toolkit},
author={Graham Neubig and Chris Dyer and Yoav Goldberg and Austin Matthews and Waleed Ammar and Antonios Anastasopoulos and Miguel Ballesteros and David Chiang and Daniel Clothiaux and Trevor Cohn and Kevin Duh and Manaal Faruqui and Cynthia Gan and Dan Garrette and Yangfeng Ji and Lingpeng Kong and Adhiguna Kuncoro and Gaurav Kumar and Chaitanya Malaviya and Paul Michel and Yusuke Oda and Matthew Richardson and Naomi Saphra and Swabha Swayamdipta and Pengcheng Yin},
journal={arXiv preprint arXiv:1701.03980},
year={2017}
}
If you want to talk about DyNet in presentations, blog posts, etc., feel free to use one of the logos below!


