Operations

Operation Interface

The following functions define DyNet “Expressions,” which are used as an interface to the various functions that can be used to build DyNet computation graphs. Expressions for each specific function are listed below.

struct dynet::Expression
#include <expr.h>

Expressions are the building block of a Dynet computation graph.

Public Functions

dynet::Expression::Expression(ComputationGraph *pg, VariableIndex i)

Base expression constructor.

Used when creating operations

Parameters
  • pg: Pointer to the computation graph
  • i: Variable index

const Tensor &dynet::Expression::value() const

Get value of the expression.

Throws a runtime_error exception if no computation graph is available

Return
Value of the expression as a tensor

const Tensor &dynet::Expression::gradient() const

Get gradient of the expression.

Throws a runtime_error exception if no computation graph is available

Make sure to call backward on a downstream expression before calling this.

If the expression is a constant expression (meaning it’s not a function of a parameter), DyNet won’t compute its gradient for the sake of efficiency. You need to manually force the gradient computation by adding the argument full=true to backward

Return
Gradient of the expression as a tensor

const Dim &dynet::Expression::dim() const

Get dimension of the expression.

Throws a runtime_error exception if no computation graph is available

Return
Dimension of the expression
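
For illustration, here is a minimal sketch of this interface in use (assuming the standard DyNet headers and the ParameterCollection API; the setup boilerplate is not part of this reference):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <iostream>
    using namespace dynet;

    int main(int argc, char** argv) {
      initialize(argc, argv);
      ParameterCollection model;                     // holds the parameters
      Parameter p_w = model.add_parameters({1, 3});  // a 1x3 weight matrix

      ComputationGraph cg;
      Expression w = parameter(cg, p_w);               // load parameter
      Expression x = input(cg, {3}, {1.f, 2.f, 3.f});  // 3-vector input
      Expression y = w * x;                            // scalar score

      std::cout << y.dim() << std::endl;               // dimension: {1}
      cg.forward(y);                                   // compute values
      cg.backward(y);                                  // then gradients
      std::cout << as_scalar(y.value()) << std::endl;  // value as a float
      const Tensor& g_w = w.gradient();                // gradient w.r.t. w
      // For gradients of constant expressions, use cg.backward(y, true),
      // i.e. the full=true flag mentioned above.
      return 0;
    }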

Input Operations

These operations allow you to input something into the computation graph, either simple scalar/vector/matrix inputs from floats, or parameter inputs from a DyNet parameter object. They all require passing a computation graph as input so that DyNet knows which graph is being used for this particular calculation.

Expression dynet::input(ComputationGraph &g, real s, Device *device = dynet::default_device)

Scalar input.

Create an expression that represents the scalar value s

Return
An expression representing s
Parameters
  • g: Computation graph
  • s: Real number
  • device: The device on which to place the input value (default_device by default)

Expression dynet::input(ComputationGraph &g, const real *ps, Device *device = dynet::default_device)

Modifiable scalar input.

Create an expression that represents the scalar value *ps. If *ps is changed and the computation graph recalculated, the next forward pass will reflect the new value.

Return
An expression representing *ps
Parameters
  • g: Computation graph
  • ps: Real number pointer
  • device: The device on which to place the input value (default_device by default)
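
A brief sketch of how a pointer input behaves across forward passes (the surrounding initialization is assumed):

    ComputationGraph cg;
    real s = 1.0;
    Expression x = input(cg, &s);   // expression bound to the address of s
    Expression y = x * 3.f;
    std::cout << as_scalar(cg.forward(y)) << std::endl;  // prints 3
    s = 2.0;                        // change the value behind the pointer
    std::cout << as_scalar(cg.forward(y)) << std::endl;  // prints 6, same graph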

Expression dynet::input(ComputationGraph &g, const Dim &d, const std::vector<float> &data, Device *device = dynet::default_device)

Vector/matrix/tensor input.

Create an expression that represents a vector, matrix, or tensor input. The dimensions of the input are defined by d, so for example:

  • input(g,{50},data): will result in a 50-length vector
  • input(g,{50,30},data): will result in a 50x30 matrix

and so on, for an arbitrary number of dimensions. This function can also be used to import minibatched inputs. For example, if we have 10 examples in a minibatch, each with size 50x30, then we call input(g,Dim({50,30},10),data). The data vector “data” will contain the values used to fill the input, in column-major format. Its length must equal the product of all dimensions in d.

Return
An expression representing data
Parameters
  • g: Computation graph
  • d: Dimension of the input matrix
  • data: A vector of data points
  • device: The device on which to place the input value (default_device by default)
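
As a sketch, a 2x2 matrix filled in column-major order, and a mini-batched variant (the values are illustrative):

    ComputationGraph cg;
    std::vector<float> vals = {1, 3, 2, 4};  // columns (1,3) and (2,4)
    Expression m = input(cg, {2, 2}, vals);  // the matrix [[1, 2], [3, 4]]
    // A mini-batch of 10 such matrices: the vector must hold 2*2*10 values.
    std::vector<float> batch_vals(2 * 2 * 10, 0.f);
    Expression mb = input(cg, Dim({2, 2}, 10), batch_vals);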

Expression dynet::input(ComputationGraph &g, const Dim &d, const std::vector<float> *pdata, Device *device = dynet::default_device)

Updatable vector/matrix/tensor input.

Similar to the input operation that takes a vector reference, this creates a vector, matrix, or tensor input. Because the data is passed by pointer, it can be updated.

Return
An expression representing *pdata
Parameters
  • g: Computation graph
  • d: Dimension of the input matrix
  • pdata: A pointer to an (updatable) vector of data points
  • device: The device on which to place the input value (default_device by default)

Expression dynet::input(ComputationGraph &g, const Dim &d, const std::vector<unsigned int> &ids, const std::vector<float> &data, float defdata = 0.f, Device *device = dynet::default_device)

Sparse vector input.

This operation takes input as a sparse matrix of index/value pairs. It is exactly the same as the standard input via vector reference, but sets all non-specified values to “defdata” and resets all others to the appropriate input values.

Return
An expression representing data
Parameters
  • g: Computation graph
  • d: Dimension of the input matrix
  • ids: The indexes of the data points to update
  • data: The data points corresponding to each index
  • defdata: The default data with which to set the unspecified data points
  • device: The device on which to place the input value (default_device by default)

Expression dynet::parameter(ComputationGraph &g, Parameter p)

Load parameter.

Load parameters into the computation graph.

Return
An expression representing p
Parameters
  • g: Computation graph
  • p: Parameter object to load

Expression dynet::parameter(ComputationGraph &g, LookupParameter lp)

Load lookup parameter.

Load a full tensor of lookup parameters into the computation graph. Normally lookup parameters are accessed by using the lookup() function to grab a single element. However, in some cases we may want to access the entire set of lookup parameters, and in that case we can use this function. The first dimensions of the returned tensor will be equivalent to the dimensions we would get when calling the lookup() function, and the size of the final dimension will be equal to the size of the vocabulary.

Return
An expression representing lp
Parameters
  • g: Computation graph
  • lp: LookupParameter object to load

Expression dynet::const_parameter(ComputationGraph &g, Parameter p)

Load constant parameters.

Load parameters into the computation graph, but prevent them from being updated when performing parameter update.

Return
An expression representing the constant p
Parameters
  • g: Computation graph
  • p: Parameter object to load

Expression dynet::const_parameter(ComputationGraph &g, LookupParameter lp)

Load constant lookup parameters.

Load lookup parameters into the computation graph, but prevent them from being updated when performing parameter update.

Return
An expression representing the constant lp
Parameters
  • g: Computation graph
  • lp: LookupParameter object to load

Expression dynet::lookup(ComputationGraph &g, LookupParameter p, unsigned index)

Look up parameter.

Look up parameters according to an index, and load them into the computation graph.

Return
An expression representing p[index]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • index: Index of the parameters within p

Expression dynet::lookup(ComputationGraph &g, LookupParameter p, const unsigned *pindex)

Look up parameters with modifiable index.

Look up parameters according to the *pindex, and load them into the computation graph. When *pindex changes, on the next computation of forward() the values will change.

Return
An expression representing p[*pindex]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • pindex: Pointer index of the parameters within p

Expression dynet::const_lookup(ComputationGraph &g, LookupParameter p, unsigned index)

Look up parameter.

Look up parameters according to an index, and load them into the computation graph. Do not perform gradient update on the parameters.

Return
A constant expression representing p[index]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • index: Index of the parameters within p

Expression dynet::const_lookup(ComputationGraph &g, LookupParameter p, const unsigned *pindex)

Constant lookup parameters with modifiable index.

Look up parameters according to *pindex, and load them into the computation graph. When *pindex changes, on the next computation of forward() the values will change. However, gradient updates will not be performed.

Return
A constant expression representing p[*pindex]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • pindex: Pointer index of the parameters within p

Expression dynet::lookup(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> &indices)

Look up parameters.

The mini-batched version of lookup. The resulting expression will be a mini-batch of parameters, where the “i”th element of the batch corresponds to the parameters at the position specified by the “i”th element of “indices”

Return
An expression with the “i”th batch element representing p[indices[i]]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • indices: Index of the parameters at each position in the batch
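
For example, a sketch of embedding a mini-batch of word ids in a single call (vocabulary size and embedding dimension here are illustrative):

    ParameterCollection model;
    LookupParameter embed = model.add_lookup_parameters(1000, {64});  // 1000 x 64
    ComputationGraph cg;
    std::vector<unsigned> ids = {17, 4, 256};  // one id per batch element
    Expression e = lookup(cg, embed, ids);     // Dim {64} with batch size 3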

Expression dynet::lookup(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> *pindices)

Look up parameters.

The mini-batched version of lookup with modifiable parameter indices.

Return
An expression with the “i”th batch element representing p[*pindices[i]]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • pindices: Pointer to lookup indices

Expression dynet::const_lookup(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> &indices)

Look up parameters.

Mini-batched lookup that will not update the parameters.

Return
A constant expression with the “i”th batch element representing p[indices[i]]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • indices: Lookup indices

Expression dynet::const_lookup(ComputationGraph &g, LookupParameter p, const std::vector<unsigned> *pindices)

Look up parameters.

Mini-batched lookup that will not update the parameters, with modifiable indices.

Return
A constant expression with the “i”th batch element representing p[*pindices[i]]
Parameters
  • g: Computation graph
  • p: LookupParameter object from which to load
  • pindices: Lookup index pointers.

Expression dynet::zeros(ComputationGraph &g, const Dim &d)

Create an input full of zeros.

Create an input full of zeros, sized according to dimensions d.

Return
A d dimensioned zero tensor
Parameters
  • g: Computation graph
  • d: The dimensions of the input

Expression dynet::ones(ComputationGraph &g, const Dim &d)

Create an input full of ones.

Create an input full of ones, sized according to dimensions d.

Return
A d dimensioned tensor of ones
Parameters
  • g: Computation graph
  • d: The dimensions of the input

Expression dynet::constant(ComputationGraph &g, const Dim &d, float val)

Create an input with one constant value.

Create an input full of val, sized according to dimensions d.

Return
A d dimensioned tensor filled with value val
Parameters
  • g: Computation graph
  • d: The dimensions of the input
  • val: The value of the input

Expression dynet::random_normal(ComputationGraph &g, const Dim &d, float mean = 0.f, float stddev = 1.0)

Create a random normal vector.

Create a vector distributed according to a normal distribution with the specified mean and standard deviation.

Return
A “d” dimensioned normally distributed vector
Parameters
  • g: Computation graph
  • d: The dimensions of the input
  • mean: The mean of the distribution (default: 0.0)
  • stddev: The standard deviation of the distribution (default: 1.0)

Expression dynet::random_bernoulli(ComputationGraph &g, const Dim &d, real p, real scale = 1.0f)

Create a random bernoulli vector.

Create a vector distributed according to a Bernoulli distribution with parameter p.

Return
A “d” dimensioned Bernoulli-distributed vector
Parameters
  • g: Computation graph
  • d: The dimensions of the input
  • p: The Bernoulli p parameter
  • scale: A scaling factor for the output (“active” elements will receive this value)

Expression dynet::random_uniform(ComputationGraph &g, const Dim &d, real left, real right)

Create a random uniform vector.

Create a vector distributed according to a uniform distribution with boundaries left and right.

Return
A “d” dimensioned uniformly distributed vector
Parameters
  • g: Computation graph
  • d: The dimensions of the input
  • left: The left boundary
  • right: The right boundary

Expression dynet::random_gumbel(ComputationGraph &g, const Dim &d, real mu = 0.0, real beta = 1.0)

Create a random Gumbel sampled vector.

Create a vector distributed according to a Gumbel distribution with the specified parameters. (Currently only the defaults of mu=0.0 and beta=1.0 are supported.)

Return
A “d” dimensioned Gumbel distributed vector
Parameters
  • g: Computation graph
  • d: The dimensions of the input
  • mu: The mu parameter
  • beta: The beta parameter

Arithmetic Operations

These operations perform basic arithmetic over values in the graph.

Expression dynet::operator-(const Expression &x)

Negation.

Negate the passed argument.

Return
The negation of x
Parameters
  • x: An input expression

Expression dynet::operator+(const Expression &x, const Expression &y)

Expression addition.

Add two expressions of the same dimensions.

Return
The sum of x and y
Parameters
  • x: The first input
  • y: The second input

Expression dynet::operator+(const Expression &x, real y)

Scalar addition.

Add a scalar to an expression

Return
An expression equal to x, with every component increased by y
Parameters
  • x: The expression
  • y: The scalar

Expression dynet::operator+(real x, const Expression &y)

Scalar addition.

Add a scalar to an expression

Return
An expression equal to y, with every component increased by x
Parameters
  • x: The scalar
  • y: The expression

Expression dynet::operator-(const Expression &x, const Expression &y)

Expression subtraction.

Subtract one expression from another.

Return
An expression where the ith element is x_i minus y_i
Parameters
  • x: The expression from which to subtract
  • y: The expression to subtract

Expression dynet::operator-(real x, const Expression &y)

Scalar subtraction.

Subtract an expression from a scalar

Return
An expression where the ith element is x_i minus y
Parameters
  • x: The scalar from which to subtract
  • y: The expression to subtract

Expression dynet::operator-(const Expression &x, real y)

Scalar subtraction.

Subtract a scalar from an expression

Return
An expression where the ith element is x_i minus y
Parameters
  • x: The expression from which to subtract
  • y: The scalar to subtract

Expression dynet::operator*(const Expression &x, const Expression &y)

Matrix multiplication.

Multiply two matrices together. Like standard matrix multiplication, the second dimension of x and the first dimension of y must match.

Return
An expression x times y
Parameters
  • x: The left-hand matrix
  • y: The right-hand matrix

Expression dynet::operator*(const Expression &x, float y)

Matrix-scalar multiplication.

Multiply an expression component-wise by a scalar.

Return
An expression where the ith element is x_i times y
Parameters
  • x: The matrix
  • y: The scalar

Expression dynet::operator*(float y, const Expression &x)

Matrix-scalar multiplication.

Multiply an expression component-wise by a scalar.

Return
An expression where the ith element is x_i times y
Parameters
  • x: The matrix
  • y: The scalar

Expression dynet::operator/(const Expression &x, float y)

Matrix-scalar division.

Divide an expression component-wise by a scalar.

Return
An expression where the ith element is x_i divided by y
Parameters
  • x: The matrix
  • y: The scalar

Expression dynet::affine_transform(const std::initializer_list<Expression> &xs)

Affine transform.

This performs an affine transform over an arbitrary (odd) number of expressions held in the input initializer list xs. The first expression is the “bias,” which is added to the expression as-is. The remaining expressions are multiplied together in pairs, then added. A very common usage case is the calculation of the score for a neural network layer (e.g. b + Wz) where b is the bias, W is the weight matrix, and z is the input. In this case xs[0] = b, xs[1] = W, and xs[2] = z.

Return
An expression equal to: xs[0] + xs[1]*xs[2] + xs[3]*xs[4] + …
Parameters
  • xs: An initializer list containing an odd number of expressions
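
A common usage sketch, computing a hidden layer h = tanh(b + Wz) (sizes are illustrative):

    ParameterCollection model;
    Parameter p_W = model.add_parameters({8, 4});
    Parameter p_b = model.add_parameters({8});
    ComputationGraph cg;
    Expression W = parameter(cg, p_W);
    Expression b = parameter(cg, p_b);
    Expression z = input(cg, {4}, {1.f, 2.f, 3.f, 4.f});
    Expression h = tanh(affine_transform({b, W, z}));  // b + W*z, then tanh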

Expression dynet::sum(const std::initializer_list<Expression> &xs)

Sum.

This performs an elementwise sum over all the expressions in xs

Return
An expression where the ith element is equal to xs[0][i] + xs[1][i] + …
Parameters
  • xs: An initializer list containing expressions

Expression dynet::sum_elems(const Expression &x)

Sum all elements.

Sum all the elements in an expression.

Return
The sum of all of its elements
Parameters
  • x: The input expression

Expression dynet::moment_elems(const Expression &x, unsigned r)

Compute moment over all elements.

Compute the moment of order \(r\), \(\frac 1 n\sum_{i=1}^nx_i^r\) over all the elements in each batch of the expression

Return
A scalar expression (with a potential batch dimension)
Parameters
  • x: The input mini-batched expression
  • r: Order of the moment

Expression dynet::mean_elems(const Expression &x)

Compute mean over all elements.

Computes \(\frac 1 n\sum_{i=1}^nx_i\) over all the elements in each batch of the expression

Return
A scalar expression (with a potential batch dimension)
Parameters
  • x: The input mini-batched expression

Expression dynet::std_elems(const Expression &x)

Compute Standard deviation over all elements.

Computes \(\sqrt{\frac 1 n\sum_{i=1}^n(x_i -\mu)^2}\) where \(\mu=\frac 1 n\sum_{i=1}^nx_i\) over all the elements in each batch of the expression

Return
A scalar expression (with a potential batch dimension)
Parameters
  • x: The input mini-batched expression

Expression dynet::sum_batches(const Expression &x)

Sum over minibatches.

Sum an expression that consists of multiple minibatches into one of equal dimension but with only a single minibatch. This is useful for summing loss functions at the end of minibatch training.

Return
An expression with a single batch
Parameters
  • x: The input mini-batched expression
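
For instance, a sketch of reducing per-example losses to a single training loss; here scores is assumed to be a mini-batched score vector built earlier in graph cg (see the batched pickneglogsoftmax operation below):

    std::vector<unsigned> gold = {2, 0, 1};              // one label per example
    Expression losses = pickneglogsoftmax(scores, gold); // loss per batch element
    Expression total = sum_batches(losses);              // one single-batch scalar
    float l = as_scalar(cg.forward(total));
    cg.backward(total);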

Expression dynet::moment_batches(const Expression &x, unsigned r)

Compute moment over minibatches.

Compute the moment of order \(r\), \(\frac 1 n\sum_{i=1}^nx_i^r\) along the batch dimension

Return
An expression with a single batch
Parameters
  • x: The input mini-batched expression
  • r: Order of the moment

Expression dynet::mean_batches(const Expression &x)

Compute mean over minibatches.

Computes \(\frac 1 n\sum_{i=1}^nx_i\) along the batch dimension

Return
An expression with a single batch
Parameters
  • x: The input mini-batched expression

Expression dynet::std_batches(const Expression &x)

Compute standard deviation over minibatches.

Computes \(\sqrt{\frac 1 n\sum_{i=1}^n(x_i -\mu)^2}\) where \(\mu=\frac 1 n\sum_{i=1}^nx_i\) along the batch dimension

Return
An expression with a single batch
Parameters
  • x: The input mini-batched expression

Expression dynet::sum_dim(const Expression &x, const std::vector<unsigned> &dims, bool b = false)

Compute sum along a specific dimension or dimensions.

Compute the sum along a specific dimension or dimensions

Return
An expression with |dims| fewer dimensions and possibly a dropped batch dimension
Parameters
  • x: The input mini-batched expression
  • dims: Dimensions along which to reduce
  • b: Whether to include batch dimension (default: false)

Expression dynet::cumsum(const Expression &x, unsigned d)

Compute cumulative sum along a specific dimension.

Compute the cumulative sum along a specific dimension: \(y_i=\sum_{j\leq i}x_j\)

Return
An expression of the same shape as the input
Parameters
  • x: The input mini-batched expression
  • d: Dimension along which to compute the cumulative sum

Expression dynet::moment_dim(const Expression &x, const std::vector<unsigned> &dims, unsigned r, bool b = false, unsigned n = 0)

Compute moment along a specific dimension.

Compute the moment of order \(r\), \(\frac 1 n\sum_{i=1}^nx_i^r\) along a specific dimension

Return
An expression with |dims| fewer dimensions and possibly a dropped batch dimension
Parameters
  • x: The input mini-batched expression
  • dims: Dimensions along which to reduce
  • r: Order of the moment
  • b: Whether to include batch dimension (default: false)
  • n: If > 0, overwrite the n in the equation by this value, useful for masking (default: 0)

Expression dynet::mean_dim(const Expression &x, const std::vector<unsigned> &dims, bool b = false, unsigned n = 0)

Compute mean along a specific dimension.

Computes \(\frac 1 n\sum_{i=1}^nx_i\) along a specific dimension

Return
An expression with |dims| fewer dimensions and possibly a dropped batch dimension
Parameters
  • x: The input mini-batched expression
  • dims: Dimensions along which to reduce
  • b: Whether to include batch dimension (default: false)
  • n: If > 0, overwrite the n in the equation by this value, useful for masking (default: 0)

Expression dynet::std_dim(const Expression &x, const std::vector<unsigned> &dims, bool b = false, unsigned n = 0)

Compute standard deviation along an arbitrary dimension.

Computes \(\sqrt{\frac 1 n\sum_{i=1}^n(x_i -\mu)^2}\) where \(\mu=\frac 1 n\sum_{i=1}^nx_i\) along an arbitrary dimension

Return
An expression with |dims| fewer dimensions and possibly a dropped batch dimension
Parameters
  • x: The input mini-batched expression
  • dims: Dimensions along which to reduce
  • b: Whether to include batch dimension (default: false)
  • n: If > 0, overwrite the n in the equation by this value, useful for masking (default: 0)

Expression dynet::average(const std::initializer_list<Expression> &xs)

Average.

This performs an elementwise average over all the expressions in xs

Return
An expression where the ith element is equal to (xs[0][i] + xs[1][i] + …)/|xs|
Parameters
  • xs: An initializer list containing expressions

Expression dynet::sqrt(const Expression &x)

Square root.

Elementwise square root.

Return
An expression where the ith element is equal to \(\sqrt{x_i}\)
Parameters
  • x: The input expression

Expression dynet::abs(const Expression &x)

Absolute value.

Elementwise absolute value.

Return
An expression where the ith element is equal to \(\vert x_i\vert\)
Parameters
  • x: The input expression

Expression dynet::erf(const Expression &x)

Gaussian error function.

Elementwise calculation of the Gaussian error function

Return
An expression where the ith element is equal to erf(x_i)
Parameters
  • x: The input expression

Expression dynet::asin(const Expression &x)

Inverse sine.

Elementwise calculation of the inverse sine

Return
An expression where the ith element is equal to asin(x_i)
Parameters
  • x: The input expression

Expression dynet::acos(const Expression &x)

Inverse cosine.

Elementwise calculation of the inverse cosine

Return
An expression where the ith element is equal to acos(x_i)
Parameters
  • x: The input expression

Expression dynet::atan(const Expression &x)

Inverse tangent.

Elementwise calculation of the inverse tangent

Return
An expression where the ith element is equal to atan(x_i)
Parameters
  • x: The input expression

Expression dynet::sin(const Expression &x)

Sine.

Elementwise calculation of the sine

Return
An expression where the ith element is equal to sin(x_i)
Parameters
  • x: The input expression

Expression dynet::cos(const Expression &x)

Cosine.

Elementwise calculation of the cosine

Return
An expression where the ith element is equal to cos(x_i)
Parameters
  • x: The input expression

Expression dynet::tan(const Expression &x)

Tangent.

Elementwise calculation of the tangent

Return
An expression where the ith element is equal to tan(x_i)
Parameters
  • x: The input expression

Expression dynet::sinh(const Expression &x)

Hyperbolic sine.

Elementwise calculation of the hyperbolic sine

Return
An expression where the ith element is equal to sinh(x_i)
Parameters
  • x: The input expression

Expression dynet::cosh(const Expression &x)

Hyperbolic cosine.

Elementwise calculation of the hyperbolic cosine

Return
An expression where the ith element is equal to cosh(x_i)
Parameters
  • x: The input expression

Expression dynet::tanh(const Expression &x)

Hyperbolic tangent.

Elementwise calculation of the hyperbolic tangent

Return
An expression where the ith element is equal to tanh(x_i)
Parameters
  • x: The input expression

Expression dynet::asinh(const Expression &x)

Inverse hyperbolic sine.

Elementwise calculation of the inverse hyperbolic sine

Return
An expression where the ith element is equal to asinh(x_i)
Parameters
  • x: The input expression

Expression dynet::acosh(const Expression &x)

Inverse hyperbolic cosine.

Elementwise calculation of the inverse hyperbolic cosine

Return
An expression where the ith element is equal to acosh(x_i)
Parameters
  • x: The input expression

Expression dynet::atanh(const Expression &x)

Inverse hyperbolic tangent.

Elementwise calculation of the inverse hyperbolic tangent

Return
An expression where the ith element is equal to atanh(x_i)
Parameters
  • x: The input expression

Expression dynet::exp(const Expression &x)

Natural exponent.

Calculate elementwise y_i = e^{x_i}

Return
An expression where the ith element is equal to e^{x_i}
Parameters
  • x: The input expression

Expression dynet::square(const Expression &x)

Square.

Calculate elementwise y_i = x_i^2

Return
An expression where the ith element is equal to x_i^2
Parameters
  • x: The input expression

Expression dynet::cube(const Expression &x)

Cube.

Calculate elementwise y_i = x_i^3

Return
An expression where the ith element is equal to x_i^3
Parameters
  • x: The input expression

Expression dynet::log_sigmoid(const Expression &x)

Log sigmoid.

Calculate elementwise \(y_i = \ln(\frac{1}{1+e^{-x_i}})\) This is more numerically stable than log(logistic(x))

Return
An expression where the ith element is equal to \(y_i = \ln(\frac{1}{1+e^{-x_i}})\)
Parameters
  • x: The input expression

Expression dynet::lgamma(const Expression &x)

Log gamma.

Calculate elementwise y_i = ln(gamma(x_i))

Return
An expression where the ith element is equal to ln(gamma(x_i))
Parameters
  • x: The input expression

Expression dynet::log(const Expression &x)

Logarithm.

Calculate the elementwise natural logarithm y_i = ln(x_i)

Return
An expression where the ith element is equal to ln(x_i)
Parameters
  • x: The input expression

Expression dynet::logistic(const Expression &x)

Logistic sigmoid function.

Calculate elementwise y_i = 1/(1+e^{-x_i})

Return
An expression where the ith element is equal to y_i = 1/(1+e^{-x_i})
Parameters
  • x: The input expression

Expression dynet::rectify(const Expression &x)

Rectifier.

Calculate elementwise the rectifier (ReLU) function y_i = max(x_i,0)

Return
An expression where the ith element is equal to max(x_i,0)
Parameters
  • x: The input expression

Expression dynet::elu(const Expression &x, float alpha = 1.f)

Exponential Linear Unit.

Calculate elementwise the function

\( y_i = \left\{\begin{array}{lr} x_i, & \text{if } x>0\\ \alpha\times(e^{x_i} - 1), & \text{if }x\leqslant 0\\ \end{array}\right. \)

Reference: Clevert et al., 2015

Return
An expression where the ith element is equal to \(\text{ELU}(x_i, \alpha)\)
Parameters
  • x: The input expression
  • alpha: The alpha parameter (default: 1.0)

Expression dynet::selu(const Expression &x)

Scaled Exponential Linear Unit (SELU)

Calculate elementwise the function

\( y_i = \lambda\times\left\{\begin{array}{lr} x_i, & \text{if } x>0\\ \alpha\times(e^{x_i} - 1), & \text{if }x\leqslant 0\\ \end{array}\right. \)

With \( \begin{split} \lambda &=\texttt{1.0507009873554804934193349852946}\\ \alpha &=\texttt{1.6732632423543772848170429916717}\\ \end{split} \)

Reference: Klambauer et al., 2017

Return
An expression where the ith element is equal to \(\text{SELU}(x_i)\)
Parameters
  • x: The input expression

Expression dynet::silu(const Expression &x, float beta = 1.f)

SILU / SiL / Swish.

Calculate elementwise y_i = x_i / (1 + e^{-beta * x_i})

Reference: Hendrycks and Gimpel, 2016; Elfwing et al., 2017; and Ramachandran et al., 2017

Return
An expression where the ith element is equal to y_i = x_i / (1 + e^{-beta * x_i})
Parameters
  • x: The input expression
  • beta: The beta parameter (default: 1.0)

Expression dynet::softsign(const Expression &x)

Soft Sign.

Calculate elementwise the softsign function y_i = x_i/(1+|x_i|)

Return
An expression where the ith element is equal to x_i/(1+|x_i|)
Parameters
  • x: The input expression

Expression dynet::pow(const Expression &x, const Expression &y)

Power function.

Calculate an output where the ith element is equal to x_i^y_i

Return
An expression where the ith element is equal to x_i^y_i
Parameters
  • x: The input expression
  • y: The exponent expression

Expression dynet::min(const Expression &x, const Expression &y)

Minimum.

Calculate an output where the ith element is min(x_i,y_i)

Return
An expression where the ith element is equal to min(x_i,y_i)
Parameters
  • x: The first input expression
  • y: The second input expression

Expression dynet::max(const Expression &x, const Expression &y)

Maximum.

Calculate an output where the ith element is max(x_i,y_i)

Return
An expression where the ith element is equal to max(x_i,y_i)
Parameters
  • x: The first input expression
  • y: The second input expression

Expression dynet::max(const std::initializer_list<Expression> &xs)

Max.

This performs an elementwise max over all the expressions in xs

Return
An expression where the ith element is equal to max(xs[0][i], xs[1][i], …)
Parameters
  • xs: An initializer list containing expressions

Expression dynet::dot_product(const Expression &x, const Expression &y)

Dot Product.

Calculate the dot product sum_i x_i*y_i

Return
An expression equal to the dot product
Parameters
  • x: The input expression
  • y: The input expression

Expression dynet::circ_conv(const Expression &u, const Expression &v)

Circular convolution.

Calculate the circular convolution

Return
An expression equal to the circular convolution
Parameters
  • u: The input expression
  • v: The input expression

Expression dynet::circ_corr(const Expression &u, const Expression &v)

Circular correlation.

Calculate the circular correlation

Return
An expression equal to the circular correlation
Parameters
  • u: The input expression
  • v: The input expression

Expression dynet::cmult(const Expression &x, const Expression &y)

Componentwise multiply.

Multiply two expressions component-wise, broadcasting dimensions if necessary as follows (see the sketch after this entry):

  • When the numbers of dimensions differ, dimensions of size 1 are added to make them match
  • Every dimension is then required to have matching sizes, or one of the two sizes must equal 1 (in which case it will be broadcast)
  • In the same way, the batch dimensions must match, or one of them must equal 1, in which case it will be broadcast
  • The resulting tensor’s dimensionality is thus determined as the max of both inputs at every position

Return
An expression where the ith element is equal to x_i*y_i
Parameters
  • x: The first input expression
  • y: The second input expression
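
A sketch of the broadcasting behavior: multiplying a {3,1} column by a {1,4} row broadcasts both arguments and yields a {3,4} outer product:

    ComputationGraph cg;
    Expression col = input(cg, {3, 1}, {1.f, 2.f, 3.f});
    Expression row = input(cg, {1, 4}, {1.f, 10.f, 100.f, 1000.f});
    Expression outer = cmult(col, row);  // broadcast to {3, 4}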

Expression dynet::cdiv(const Expression &x, const Expression &y)

Componentwise division.

Divide one expression component-wise by another, broadcasting dimensions (currently only of the second expression!) if necessary as follows:

  • When the numbers of dimensions differ, dimensions of size 1 are added to make them match
  • Every dimension is then required to have matching sizes, or the size of the right expression’s dimension must equal 1 (in which case it will be broadcast)
  • In the same way, the batch sizes must match, or the batch size of the right expression must equal 1, in which case it will be broadcast
  • The resulting tensor’s dimensionality is thus determined as the max of both inputs at every position

Return
An expression where the ith element is equal to x_i/y_i
Parameters
  • x: The first input expression
  • y: The second input expression

Expression dynet::colwise_add(const Expression &x, const Expression &bias)

Columnwise addition.

Add vector “bias” to each column of matrix “x”

Return
An expression where bias is added to each column of x
Parameters
  • x: An MxN matrix
  • bias: A length M vector

Probability/Loss Operations

These operations are used for calculating probabilities, or calculating loss functions for use in training.

Expression dynet::softmax(const Expression &x, unsigned d = 0)

Softmax.

The softmax function normalizes each column to ensure that all values are between 0 and 1 and sum to one, by applying \(\frac{e^{x_i}}{\sum_j e^{x_j}}\).

Return
A vector or matrix after calculating the softmax
Parameters
  • x: A vector or matrix
  • d: dimension to normalize over (default: 0)

Expression dynet::log_softmax(const Expression &x)

Log softmax.

The log softmax function normalizes each column to ensure that all values are between 0 and 1 and sum to one by applying \(\frac{e^{x_i}}{\sum_j e^{x_j}}\), then takes the log

Return
A vector or matrix after calculating the log softmax
Parameters
  • x: A vector or matrix

Expression dynet::log_softmax(const Expression &x, const std::vector<unsigned> &restriction)

Restricted log softmax.

The log softmax function calculated over only a subset of the vector elements. The elements to be included are set by the restriction variable. All elements not included in restriction are set to negative infinity.

Return
A vector with the log softmax over the specified elements
Parameters
  • x: A vector over which to calculate the softmax
  • restriction: The elements over which to calculate the softmax

Expression dynet::logsumexp_dim(const Expression &x, unsigned d)

Log, sum, exp by dimension.

The “logsumexp” function calculated over a particular dimension \(ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.

Return
The result.
Parameters
  • x: Expression with respect to which to calculate the logsumexp.
  • d: The dimension along which to do the logsumexp.

Expression dynet::logsumexp(const std::initializer_list<Expression> &xs)

Log, sum, exp.

The elementwise “logsumexp” function that calculates \(ln(\sum_i e^{xs_i})\), used in adding probabilities in the log domain.

Return
The result.
Parameters
  • xs: Expressions with respect to which to calculate the logsumexp.

Expression dynet::pickneglogsoftmax(const Expression &x, unsigned v)

Negative softmax log likelihood.

This function takes in a vector of scores x, and performs a log softmax, takes the negative, and selects the likelihood corresponding to the element v. This is perhaps the most standard loss function for training neural networks to predict one out of a set of elements.

Return
The negative log likelihood of element v after taking the softmax
Parameters
  • x: A vector of scores
  • v: The element with which to calculate the loss

Expression dynet::pickneglogsoftmax(const Expression &x, const unsigned *pv)

Modifiable negative softmax log likelihood.

This function calculates the negative log likelihood after the softmax with respect to index *pv. This computes the same value as the previous function that passes the index v by value, but instead passes by pointer so the value *pv can be modified without re-constructing the computation graph. This can be used in situations where we want to create a computation graph once, then feed it different data points.

Return
The negative log likelihood of element *pv after taking the softmax
Parameters
  • x: A vector of scores
  • pv: A pointer to the index of the correct element
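
A sketch of the build-once pattern; dataset and its fields here are hypothetical placeholders for your own data structures:

    ComputationGraph cg;
    std::vector<float> x_vals(4, 0.f);
    unsigned gold = 0;
    Expression x = input(cg, {4}, &x_vals);         // values read via pointer
    Expression loss = pickneglogsoftmax(x, &gold);  // index read via pointer
    for (const auto& ex : dataset) {   // `dataset` is hypothetical
      x_vals = ex.features;            // overwrite the values in place...
      gold = ex.label;                 // ...and the gold index
      float l = as_scalar(cg.forward(loss));
      cg.backward(loss);
    }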

Expression dynet::pickneglogsoftmax(const Expression &x, const std::vector<unsigned> &v)

Batched negative softmax log likelihood.

This function is similar to standard pickneglogsoftmax, but calculates loss with respect to multiple batch elements. The input will be a mini-batch of score vectors where the number of batch elements is equal to the number of indices in v.

Return
The negative log likelihoods over all the batch elements
Parameters
  • x: An expression with vectors of scores over N batch elements
  • v: A size-N vector indicating the index with respect to all the batch elements

Expression dynet::pickneglogsoftmax(const Expression &x, const std::vector<unsigned> *pv)

Modifiable batched negative softmax log likelihood.

This function is a combination of modifiable pickneglogsoftmax and batched pickneglogsoftmax: pv can be modified without re-creating the computation graph.

Return
The negative log likelihoods over all the batch elements
Parameters
  • x: An expression with vectors of scores over N batch elements
  • pv: A pointer to the indexes

Expression dynet::hinge(const Expression &x, unsigned index, float m = 1.0)

Hinge loss.

This expression calculates the hinge loss, formally expressed as: \( \text{hinge}(x,index,m) = \sum_{i \ne index} \max(0, m-x[index]+x[i]). \)

Return
The hinge loss of candidate index with respect to margin m
Parameters
  • x: A vector of scores
  • index: The index of the correct candidate
  • m: The margin

Expression dynet::hinge(const Expression &x, const unsigned *pindex, float m = 1.0)

Modifiable hinge loss.

This function calculates the hinge loss with respect to index *pindex. This computes the same value as the previous function that passes the index by value, but instead passes by pointer so the value *pindex can be modified without re-constructing the computation graph. This can be used in situations where we want to create a computation graph once, then feed it different data points.

Return
The hinge loss of candidate *pindex with respect to margin m
Parameters
  • x: A vector of scores
  • pindex: A pointer to the index of the correct candidate
  • m: The margin

Expression dynet::hinge(const Expression &x, const std::vector<unsigned> &indices, float m = 1.0)

Batched hinge loss.

The same as hinge loss, but for the case where x is a mini-batched tensor with indices.size() batch elements, and indices is a vector indicating the index of each of the correct elements for these elements.

Return
The hinge loss of each mini-batch
Parameters
  • x: A mini-batch of vectors with indices.size() batch elements
  • indices: The indices of the correct candidates for each batch element
  • m: The margin

Expression dynet::hinge(const Expression &x, const std::vector<unsigned> *pindices, float m = 1.0)

Batched modifiable hinge loss.

A combination of the previous batched and modifiable hinge loss functions, where vector *pindices can be modified.

Return
The hinge loss of each mini-batch
Parameters
  • x: A mini-batch of vectors with indices.size() batch elements
  • pindices: Pointer to the indices of the correct candidates for each batch element
  • m: The margin

Expression dynet::hinge_dim(const Expression &x, const std::vector<unsigned> &indices, unsigned d = 0, float m = 1.0)

Dimensionwise hinge loss.

This expression calculates the hinge loss over a particular dimension d.

Return
A vector of hinge losses for each index in indices.
Parameters
  • x: A matrix of scores
  • indices: The indices of the correct candidate (equal in length to the dimension not specified by “d”)
  • d: The dimension over which to calculate the loss (0 or 1)
  • m: The margin

Expression dynet::hinge_dim(const Expression &x, const std::vector<unsigned> *pindex, unsigned d = 0, float m = 1.0)

Modifiable dimensionwise hinge loss.

This function calculates the modifiable version of dimensionwise hinge loss.

Return
A vector of hinge losses for each index in indices.
Parameters
  • x: A matrix of scores
  • pindex: A pointer to the indices of the correct candidates
  • d: The dimension over which to calculate the loss (0 or 1)
  • m: The margin

Expression dynet::hinge_dim(const Expression &x, const std::vector<std::vector<unsigned>> &indices, unsigned d = 0, float m = 1.0)

Batched dimensionwise hinge loss.

The same as dimensionwise hinge loss, but for the case where x is a mini-batched tensor with indices.size() batch elements.

Return
A vector of hinge losses for each mini-batch
Parameters
  • x: A mini-batch of vectors with indices.size() batch elements
  • indices: The indices of the correct candidates for each batch element
  • d: The dimension over which to calculate the loss (0 or 1)
  • m: The margin

Expression dynet::hinge_dim(const Expression &x, const std::vector<std::vector<unsigned>> *pindices, unsigned d = 0, float m = 1.0)

Batched modifiable hinge loss.

A combination of the previous batched and modifiable hinge loss functions, where vector *pindices can be modified.

Return
The hinge loss of each mini-batch
Parameters
  • x: A mini-batch of vectors with indices.size() batch elements
  • pindices: Pointer to the indices of the correct candidates for each batch element
  • d: The dimension over which to calculate the loss (0 or 1)
  • m: The margin

Expression dynet::sparsemax(const Expression &x)

Sparsemax.

The sparsemax function (Martins et al. 2016), which is similar to softmax, but induces sparse solutions where most of the vector elements are zero. Note: This function is not yet implemented on GPU.

Return
The sparsemax of the scores
Parameters
  • x: A vector of scores

Expression dynet::sparsemax_loss(const Expression &x, const std::vector<unsigned> &target_support)

Sparsemax loss.

The sparsemax loss function (Martins et al. 2016), which is similar to softmax loss, but induces sparse solutions where most of the vector elements are zero. It has a gradient similar to the sparsemax function and thus is useful for optimizing when the sparsemax will be used at test time. Note: This function is not yet implemented on GPU.

Return
The sparsemax loss of the labels
Parameters
  • x: A vector of scores
  • target_support: The target correct labels.

Expression dynet::sparsemax_loss(const Expression &x, const std::vector<unsigned> *ptarget_support)

Modifiable sparsemax loss.

Similar to the sparsemax loss, but with ptarget_support being a pointer to a vector, allowing it to be modified without re-creating the computation graph. Note: This function is not yet implemented on GPU.

Return
The sparsemax loss of the labels
Parameters
  • x: A vector of scores
  • ptarget_support: A pointer to the target correct labels.

Expression dynet::constrained_softmax(const Expression &x, const Expression &y)

Constrained softmax.

The constrained softmax function. Note: This function is not yet implemented on GPU.

Return
The constrained softmax of the scores.
Parameters
  • x: A vector of scores
  • y: A vector of upper bound constraints on probabilities

Expression dynet::squared_norm(const Expression &x)

Squared norm.

The squared L2 norm of the values of x: \(\sum_i x_i^2\).

Return
The squared L2 norm
Parameters
  • x: A vector of values

Expression dynet::l2_norm(const Expression &x)

L2 norm.

The L2 norm of the values of x: \(\sqrt{\sum_i x_i^2}\).

Return
The L2 norm
Parameters
  • x: A vector of values

Expression dynet::squared_distance(const Expression &x, const Expression &y)

Squared distance.

The squared distance between values of x and y: \(\sum_i (x_i-y_i)^2\).

Return
The squared distance
Parameters
  • x: A vector of values
  • y: Another vector of values

Expression dynet::l1_distance(const Expression &x, const Expression &y)

L1 distance.

The L1 distance between values of x and y: \(\sum_i |x_i-y_i|\).

Return
The L1 distance
Parameters
  • x: A vector of values
  • y: Another vector of values

Expression dynet::huber_distance(const Expression &x, const Expression &y, float c = 1.345f)

Huber distance.

The Huber distance between values of x and y parameterized by c, \(\sum_i L_c(x_i, y_i)\) where:

\( L_c(x, y) = \begin{cases} \frac{1}{2}(y - x)^2 & \textrm{for } |y - x| \le c, \\ c\, |y - x| - \frac{1}{2}c^2 & \textrm{otherwise.} \end{cases} \)

Return
The Huber distance
Parameters
  • x: A vector of values
  • y: Another vector of values
  • c: The parameter of the Huber distance parameterizing the cutoff

Expression dynet::binary_log_loss(const Expression &x, const Expression &y)

Binary log loss.

The log loss of a binary decision according to the sigmoid function \(- \sum_i (y_i * ln(x_i) + (1-y_i) * ln(1-x_i)) \)

Return
The log loss of the sigmoid function
Parameters
  • x: A vector of values
  • y: A vector of true answers

Expression dynet::pairwise_rank_loss(const Expression &x, const Expression &y, real m = 1.0)

Pairwise rank loss.

A margin-based loss, where every margin violation for each pair of values is penalized: \(\sum_i max(m - x_i + y_i, 0)\)

Return
The pairwise rank loss
Parameters
  • x: A vector of values
  • y: A vector of true answers
  • m: The margin

Expression dynet::poisson_loss(const Expression &x, unsigned y)

Poisson loss.

The negative log probability of y according to a Poisson distribution with parameter x. Useful in Poisson regression, where we try to predict the parameters of a Poisson distribution to maximize the probability of the data y.

Return
The Poisson loss
Parameters
  • x: The parameter of the Poisson distribution.
  • y: The target value

Expression dynet::poisson_loss(const Expression &x, const unsigned *py)

Modifiable Poisson loss.

Similar to Poisson loss, but with the target value passed by pointer so that it can be modified without re-constructing the computation graph.

Return
The Poisson loss
Parameters
  • x: The parameter of the Poisson distribution.
  • py: A pointer to the target value

Flow/Shaping Operations

These operations control the flow of information through the graph, or the shape of the vectors/tensors used in the graph.

enum flowoperations::ArgmaxGradient

Gradient modes for the argmax operation.

Values:

flowoperations::zero_gradient
flowoperations::straight_through_gradient
Expression dynet::nobackprop(const Expression &x)

Prevent backprop.

This node has no effect on the forward pass, but prevents gradients from flowing backward during the backward pass. This is useful when there’s a subgraph for which you don’t want loss passed back to the parameters.

Return
The new expression
Parameters
  • x: The input expression

Expression dynet::flip_gradient(const Expression &x)

Flip gradient.

This node has no effect on the forward pass, but inverts the gradient on backprop. This operation is widely used in adversarial networks.

Return
An output expression containing the same values as the input (only affects the backprop process)
Parameters
  • x: The input expression

Expression dynet::scale_gradient(const Expression &x, float lambd = 1.0f)

Scale gradient by constant.

This node has no effect on the forward pass, but scales the gradient by lambda on backprop

Return
An output expression containing the same values as the input (only affects the backprop process)
Parameters
  • x: The input expression
  • lambd: The gradient scaling factor (default: 1.0)

Expression dynet::argmax(const Expression &x, ArgmaxGradient gradient_mode)

Argmax.

This node takes an input vector \(x\) and returns a one hot vector \(y\) such that \(y_{\text{argmax} x}=1\)

There are two gradient modes for this operation:

argmax(x, zero_gradient)

is the standard argmax operation. Note that this function is differentiable almost everywhere, and where it is differentiable its gradient is 0, so it will stop your gradient flow.

argmax(x, straight_through_gradient)

This gradient mode implements the straight-through estimator (Bengio et al., 2013). Its forward pass is the same as the argmax operation, but its gradient is the same as the identity function. Note that this does not technically correspond to a differentiable function (hence the name “estimator”).

Tensors of order \(>1\) are not supported yet

Return
The one hot argmax vector
Parameters
  • x: The input vector (can be batched)
  • gradient_mode: Specify the gradient type (zero or straight-through)
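
A sketch of the two modes; the graph cg and the logits vector are assumed to be defined elsewhere:

    Expression scores = input(cg, {5}, logits);       // `logits` assumed defined
    Expression hard = argmax(scores, zero_gradient);  // one-hot, gradient 0
    Expression st = argmax(scores, straight_through_gradient);  // identity gradient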

Expression dynet::reshape(const Expression &x, const Dim &d)

Reshape to another size.

This node reshapes a tensor to another size, without changing the underlying layout of the data. The layout of the data in DyNet is column-major, so if we have a 3x4 matrix

\( \begin{pmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} \\ \end{pmatrix} \)

and transform it into a 2x6 matrix, it will be rearranged as:

\( \begin{pmatrix} x_{1,1} & x_{3,1} & x_{2,2} & x_{1,3} & x_{3,3} & x_{2,4} \\ x_{2,1} & x_{1,2} & x_{3,2} & x_{2,3} & x_{1,4} & x_{3,4} \\ \end{pmatrix} \)

Note: This is O(1) for forward, and O(n) for backward.

Return
The reshaped expression
Parameters
  • x: The input expression
  • d: The new dimensions
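
A sketch mirroring the example above; the underlying column-major buffer is reinterpreted rather than copied:

    ComputationGraph cg;
    std::vector<float> vals(12);
    for (int i = 0; i < 12; ++i) vals[i] = float(i);  // column-major fill
    Expression m34 = input(cg, {3, 4}, vals);
    Expression m26 = reshape(m34, Dim({2, 6}));       // same buffer, new shape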

Expression dynet::transpose(const Expression &x, const std::vector<unsigned> &dims = {1, 0})

Transpose a matrix.

Transpose a matrix or tensor, or if dims is specified shuffle the dimensions arbitrarily. Note: This is O(1) if either the row or column dimension is 1, and O(n) otherwise.

Return
The transposed/shuffled expression
Parameters
  • x: The input expression
  • dims: The dimensions to swap. The ith dimension of the output will be equal to the dims[i] dimension of the input. dims must have the same number of dimensions as x.

Expression dynet::select_rows(const Expression &x, const std::vector<unsigned> &rows)

Select rows.

Select a subset of rows of a matrix.

Return
An expression containing the selected rows
Parameters
  • x: The input expression
  • rows: The rows to extract

Expression dynet::select_rows(const Expression &x, const std::vector<unsigned> *prows)

Modifiable select rows.

Select a subset of rows of a matrix, where the elements of prows can be modified without re-creating the computation graph.

Return
An expression containing the selected rows
Parameters
  • x: The input expression
  • prows: The rows to extract

Expression dynet::select_cols(const Expression &x, const std::vector<unsigned> &cols)

Select columns.

Select a subset of columns of a matrix. select_cols is more efficient than select_rows since DyNet uses column-major order.

Return
An expression containing the selected columns
Parameters
  • x: The input expression
  • cols: The columns to extract

Expression dynet::select_cols(const Expression &x, const std::vector<unsigned> *pcols)

Modifiable select columns.

Select a subset of columns of a matrix, where the elements of pcols can be modified without re-creating the computation graph.

Return
An expression containing the selected columns
Parameters
  • x: The input expression
  • pcols: The columns to extract

Expression dynet::pick(const Expression &x, unsigned v, unsigned d = 0)

Pick element.

Pick a single element/row/column/sub-tensor from an expression. This will result in the dimension of the tensor being reduced by 1.

Return
The value of x[v] along dimension d
Parameters
  • x: The input expression
  • v: The index of the element to select
  • d: The dimension along which to choose the element

Expression dynet::pick(const Expression &x, const std::vector<unsigned> &v, unsigned d = 0)

Batched pick.

Pick elements from multiple batches.

Return
A mini-batched expression containing the picked elements
Parameters
  • x: The input expression
  • v: A vector of indices to choose, one for each batch in the input expression.
  • d: The dimension along which to choose the elements

Expression dynet::pick(const Expression &x, const unsigned *pv, unsigned d = 0)

Modifiable pick element.

Pick a single element from an expression, where the index is passed by pointer so we do not need to re-create the computation graph every time.

Return
The value of x[*pv]
Parameters
  • x: The input expression
  • pv: Pointer to the index of the element to select
  • d: The dimension along which to choose the elements

Expression dynet::pick(const Expression &x, const std::vector<unsigned> *pv, unsigned d = 0)

Modifiable batched pick element.

Pick multiple elements from an input expression, where the indices are passed by pointer so we do not need to re-create the computation graph every time.

Return
A mini-batched expression containing the picked elements
Parameters
  • x: The input expression
  • pv: A pointer to a vector of indices to choose
  • d: The dimension along which to choose the elements

Expression dynet::pick_range(const Expression &x, unsigned s, unsigned e, unsigned d = 0)

Pick range of elements.

Pick a range of elements from an expression.

Return
The value of {x[s],…,x[e-1]} along dimension d
Parameters
  • x: The input expression
  • s: The start index
  • e: The end index
  • d: The dimension along which to pick

Expression dynet::pick_batch_elem(const Expression &x, unsigned v)

(Modifiable) Pick batch element.

Pick batch element from a batched expression. For a Tensor with 3 batch elements:

\( \begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)

pick_batch_elem(t, 1) will return a Tensor of

\( \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \)

Return
The expression of the picked batch element. The picked element is a tensor whose batch size equals one.
Parameters
  • x: The input expression
  • v: The index of the batch element to be picked.

Expression dynet::pick_batch_elems(const Expression &x, const std::vector<unsigned> &v)

(Modifiable) Pick batch elements.

Pick several batch elements from a batched expression. For a Tensor with 3 batch elements:

\( \begin{pmatrix} x_{1,1,1} & x_{1,1,2} \\ x_{1,2,1} & x_{1,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)

pick_batch_elems(t, {1, 2}) will return a Tensor with 2 batch elements:

\( \begin{pmatrix} x_{2,1,1} & x_{2,1,2} \\ x_{2,2,1} & x_{2,2,2} \\ \end{pmatrix} \begin{pmatrix} x_{3,1,1} & x_{3,1,2} \\ x_{3,2,1} & x_{3,2,2} \\ \end{pmatrix} \)

Return
The expression of the picked batch elements. The result is a tensor whose batch size equals the size of vector v.
Parameters
  • x: The input expression
  • v: A vector of indices of the batch elements to be picked.

Expression dynet::pick_batch_elem(const Expression &x, const unsigned *v)

Pick batch element.

Pick batch element from a batched expression.

Return
The expression of the picked batch element. The picked element is a tensor whose batch size equals one.
Parameters
  • x: The input expression
  • v: A pointer to the index of the batch element to be picked.

Expression dynet::pick_batch_elems(const Expression &x, const std::vector<unsigned> *pv)

Pick batch elements.

Pick several batch elements from a batched expression.

Return
The expression of the picked batch elements. The result is a tensor whose batch size equals the size of the vector pointed to by pv.
Parameters
  • x: The input expression
  • pv: A pointer to the indices of the batch elements to be picked

Expression dynet::concatenate_to_batch(const std::initializer_list<Expression> &xs)

Concatenate list of expressions to a single batched expression.

Perform a concatenation of several expressions along the batch dimension. All expressions must have the same shape except for the batch dimension.

Return
The expression with the batch dimensions concatenated
Parameters
  • xs: The input expressions

Expression dynet::strided_select(const Expression &x, const std::vector<int> &strides, const std::vector<int> &from = {}, const std::vector<int> &to = {})

Strided select in multiple dimensions.

Select a range and/or stride of elements from an expression.

Return
The value of x[from[0]:to[0]:strides[0],..] (as it would be in numpy syntax)
Parameters
  • x: The input expression
  • strides: List of strides for each dimension, must be >= 1. Dimensions not included default to 1. Batch dimension can be included as very last dimension.
  • from: List of 0-based offsets (inclusive) for each dimension, must be >= 0. Dimensions not included default to 0. Batch dimension can be included as very last dimension.
  • to: List of highest 0-based index to select (exclusive) for each dimension, must be >= 0. Dimensions not included default to the corresponding dim size. Batch dimension can be included as very last dimension.

Expression dynet::concatenate_cols(const std::initializer_list<Expression> &xs)

Concatenate columns.

Perform a concatenation of the columns in multiple expressions. All expressions must have the same number of rows.

Return
The expression with the columns concatenated
Parameters
  • xs: The input expressions

Expression dynet::concatenate(const std::initializer_list<Expression> &xs, unsigned d = 0)

Concatenate.

Perform a concatenation of multiple expressions along a particular dimension. All expressions must have the same dimensions except for the dimension to be concatenated (rows by default).

Return
The expression with the specified dimension concatenated
Parameters
  • xs: The input expressions
  • d: The dimension along which to perform concatenation
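
A small sketch (arbitrary contents) showing concatenation of two 2x3 matrices along rows (the default, d = 0) and along columns (d = 1):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <iostream>
    #include <vector>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      dynet::ComputationGraph cg;
      std::vector<float> d1(6, 1.f), d2(6, 2.f);
      dynet::Expression m1 = dynet::input(cg, {2, 3}, d1);
      dynet::Expression m2 = dynet::input(cg, {2, 3}, d2);
      dynet::Expression rows = dynet::concatenate({m1, m2});     // Dim {4,3}
      dynet::Expression cols = dynet::concatenate({m1, m2}, 1);  // Dim {2,6}
      std::cout << rows.dim() << ' ' << cols.dim() << std::endl;
      return 0;
    }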

Expression dynet::max_dim(const Expression &x, unsigned d = 0)

Max out through a dimension.

Select out the element/row/column/sub-tensor of an expression with the maximum value along a given dimension. This reduces the order of the tensor by 1.

Return
An expression of sub-tensor with max value along dimension d
Parameters
  • x: The input expression
  • d: The dimension along which to choose the element

Expression dynet::min_dim(const Expression &x, unsigned d = 0)

Min out through a dimension.

Select out the element/row/column/sub-tensor of an expression with the minimum value along a given dimension. This reduces the order of the tensor by 1.

Return
An expression of sub-tensor with min value along dimension d
Parameters
  • x: The input expression
  • d: The dimension along which to choose the element
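
A sketch of both reductions on a 3x2 matrix (values chosen so the results are easy to verify by hand):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <dynet/tensor.h>
    #include <iostream>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      dynet::ComputationGraph cg;
      // 3x2 matrix in column-major order: columns are (1,4,2) and (5,3,6).
      dynet::Expression m =
          dynet::input(cg, {3, 2}, {1.f, 4.f, 2.f, 5.f, 3.f, 6.f});
      dynet::Expression mx = dynet::max_dim(m, 0);  // shape {2}: (4, 6)
      dynet::Expression mn = dynet::min_dim(m, 1);  // shape {3}: (1, 3, 2)
      std::cout << dynet::as_vector(cg.forward(mx))[0] << std::endl;  // 4
      return 0;
    }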

Noise Operations

These operations are used to add noise to the graph in order to make learning more robust.

Expression dynet::noise(const Expression &x, real stddev)

Gaussian noise.

Add Gaussian noise to an expression.

Return
The noised expression
Parameters
  • x: The input expression
  • stddev: The standard deviation of the Gaussian noise

Expression dynet::dropout(const Expression &x, real p)

Dropout.

With a fixed probability p, drop out (set to zero) nodes in the input expression, and scale the remaining nodes by 1/(1-p) so that their expected value is unchanged. Note that there are two kinds of dropout:

  • Regular dropout: dropout is performed at training time, and the outputs are scaled by the retention probability at test time.
  • Inverted dropout: dropout and scaling are both performed at training time, so nothing needs to be done at test time.

DyNet implements the latter, so you only need to apply dropout at training time, and do not need to perform any scaling at test time.

Return
The dropped out expression
Parameters
  • x: The input expression
  • p: The dropout probability
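
In practice this means gating the dropout call on a training flag. A minimal sketch (the train flag and the rate 0.5 are illustrative, not part of the API):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>

    // Apply inverted dropout to a hidden layer only at training time.
    // At test time the expression is returned unchanged: no rescaling needed.
    dynet::Expression maybe_dropout(const dynet::Expression& h, bool train) {
      return train ? dynet::dropout(h, 0.5f) : h;
    }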

Expression dynet::dropout_dim(const Expression &x, unsigned d, real p)

Dropout along a specific dimension.

Identical to the dropout operation, except that the dropout mask is shared across one dimension. Use this if you want to drop entire rows or columns of a matrix, for example.

For now this only supports tensors of order <= 3 (with or without batch dimension)

Return
The dropped out expression
Parameters
  • x: The input expression
  • d: The dimension along which to drop
  • p: The dropout probability

Expression dynet::dropout_batch(const Expression &x, real p)

Dropout entire elements of a minibatch.

Identical to the dropout operation, except that entire batch elements are dropped.

Return
The dropped out expression
Parameters
  • x: The input expression
  • p: The dropout probability

Expression dynet::block_dropout(const Expression &x, real p)

Block dropout.

Identical to the dropout operation, but either drops out all or no values in the expression, as opposed to making a decision about each value individually.

Return
The block dropout expression
Parameters
  • x: The input expression
  • p: The block dropout probability

Tensor Operations

These operations act on higher-order tensors.

Remark: Compiling the contraction operations takes a long time with CUDA. For this reason, only the CPU implementation is compiled by default. If you need these operations on the GPU, you need to un-comment the relevant line in the source before compiling. TODO: make this simpler.

Expression dynet::contract3d_1d(const Expression &x, const Expression &y)

Contracts a rank 3 tensor and a rank 1 tensor into a rank 2 tensor.

The resulting tensor \(z\) has coordinates \(z_{ij} = \sum_k x_{ijk} y_k\)

Return
Matrix
Parameters
  • x: Rank 3 tensor
  • y: Vector

Expression dynet::contract3d_1d_1d(const Expression &x, const Expression &y, const Expression &z)

Contracts a rank 3 tensor and two rank 1 tensors into a rank 1 tensor.

This is the equivalent of calling contract3d_1d and then performing a matrix-vector multiplication.

The resulting tensor \(t\) has coordinates \(t_i = \sum_{j,k} x_{ijk} y_k z_j\)

Return
Vector
Parameters
  • x: Rank 3 tensor
  • y: Vector
  • z: Vector

Expression dynet::contract3d_1d_1d(const Expression &x, const Expression &y, const Expression &z, const Expression &b)

Same as contract3d_1d_1d with an additional bias parameter.

This is the equivalent of calling contract3d_1d and then performing an affine transform.

The resulting tensor \(t\) has coordinates \(t_i = b_i + \sum_{j,k} x_{ijk} y_k z_j\)

Return
Vector
Parameters
  • x: Rank 3 tensor
  • y: Vector
  • z: Vector
  • b: Bias vector

Expression dynet::contract3d_1d(const Expression &x, const Expression &y, const Expression &b)

Same as contract3d_1d with an additional bias parameter.

The resulting tensor \(z\) has coordinates \(z_{ij} = b_{ij}+\sum_k x_{ijk} y_k\)

Return
Matrix
Parameters
  • x: Rank 3 tensor
  • y: Vector
  • b: Bias matrix
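
A sketch of the basic contraction (arbitrary dimensions; recall that only the CPU implementation is compiled by default):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <iostream>
    #include <vector>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      dynet::ComputationGraph cg;
      std::vector<float> tdata(3 * 4 * 5, 1.f), vdata(5, 2.f);
      dynet::Expression x = dynet::input(cg, {3, 4, 5}, tdata);
      dynet::Expression y = dynet::input(cg, {5}, vdata);
      // z_{ij} = sum_k x_{ijk} y_k: a 3x4 matrix, each entry 5 * 1 * 2 = 10.
      dynet::Expression z = dynet::contract3d_1d(x, y);
      std::cout << z.dim() << std::endl;  // {3,4}
      return 0;
    }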

Linear Algebra Operations

These operations perform various computations common in linear algebra.

Expression dynet::inverse(const Expression &x)

Matrix Inverse.

Takes the inverse of a matrix (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158). Note that back-propagating through an inverted matrix can be a source of numerical instability.

Return
The inverse of the matrix
Parameters
  • x: A square matrix

Expression dynet::logdet(const Expression &x)

Log determinant.

Takes the log of the determinant of a matrix. (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158).

Return
The log of its determinant
Parameters
  • x: A square matrix

Expression dynet::trace_of_product(const Expression &x, const Expression &y)

Trace of Matrix Product.

Takes the trace of the product of two matrices. (not implemented on GPU yet, although contributions are welcome: https://github.com/clab/dynet/issues/158).

Return
trace(x * y)
Parameters
  • x: A matrix
  • y: Another matrix
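
A small CPU sketch of these three operations on a diagonal matrix (chosen so the expected results are easy to verify by hand):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <dynet/tensor.h>
    #include <iostream>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      dynet::ComputationGraph cg;
      // diag(2, 4), stored column-major.
      dynet::Expression A = dynet::input(cg, {2, 2}, {2.f, 0.f, 0.f, 4.f});
      dynet::Expression Ainv = dynet::inverse(A);              // diag(0.5, 0.25)
      dynet::Expression ld   = dynet::logdet(A);               // log(2 * 4) ~ 2.079
      dynet::Expression tr   = dynet::trace_of_product(A, A);  // 2*2 + 4*4 = 20
      std::cout << dynet::as_vector(cg.forward(tr))[0] << std::endl;  // 20
      return 0;
    }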

Convolution Operations

These operations are convolution-related.

Expression dynet::conv2d(const Expression &x, const Expression &f, const std::vector<unsigned> &stride, bool is_valid = true)

conv2d without bias

2D convolution operator without bias parameters. ‘VALID’ and ‘SAME’ convolutions are supported. To see the distinction, consider the case where the stride is 1:

  • SAME: the output size is the same as the input size. To achieve this, the input is padded so that the filter can sweep beyond the edges of the input maps.
  • VALID: the output size shrinks by filter_size - 1, and the filters sweep only valid positions inside the input maps. No padding is needed.

In detail, assume:

  • Input feature maps: (XH x XW x XC) x N
  • Filters: FH x FW x XC x FC, 4D tensor
  • Strides: strides[0] and strides[1] are row (h) and col (w) stride, respectively.

For the SAME convolution, the output height (YH) and width (YW) are computed as:

  • YH = ceil(float(XH) / float(strides[0]))
  • YW = ceil(float(XW) / float(strides[1]))

and the paddings are computed as:

  • pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
  • pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
  • pad_top = pad_along_height / 2
  • pad_bottom = pad_along_height - pad_top
  • pad_left = pad_along_width / 2
  • pad_right = pad_along_width - pad_left

For the VALID convolution, the output height (YH) and width (YW) are computed as:

  • YH = ceil(float(XH - FH + 1) / float(strides[0]))
  • YW = ceil(float(XW - FW + 1) / float(strides[1]))

and the paddings are always zero.

Return
The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Parameters
  • x: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
  • f: 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
  • stride: the row and column strides
  • is_valid: ‘VALID’ convolution or ‘SAME’ convolution; the default is true (‘VALID’)
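
To make the shape arithmetic concrete, here is a sketch of a VALID convolution with stride 1 (all sizes are illustrative): by the formulas above, YH = YW = ceil((32 - 3 + 1) / 1) = 30.

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <dynet/model.h>
    #include <iostream>
    #include <vector>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      dynet::ParameterCollection pc;
      // Filters: FH x FW x Ci x Co = 3 x 3 x 1 x 8.
      dynet::Parameter F = pc.add_parameters({3, 3, 1, 8});
      dynet::ComputationGraph cg;
      std::vector<float> img(32 * 32, 0.5f);
      dynet::Expression x = dynet::input(cg, {32, 32, 1}, img);
      dynet::Expression f = dynet::parameter(cg, F);
      dynet::Expression y = dynet::conv2d(x, f, {1, 1}, /*is_valid=*/true);
      std::cout << y.dim() << std::endl;  // {30,30,8}
      return 0;
    }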

Expression dynet::conv2d(const Expression &x, const Expression &f, const Expression &b, const std::vector<unsigned> &stride, bool is_valid = true)

conv2d with bias

2D convolution operator with bias parameters. ‘VALID’ and ‘SAME’ convolutions are supported. To see the distinction, consider the case where the stride is 1:

  • SAME: the output size is the same as the input size. To achieve this, the input is padded so that the filter can sweep beyond the edges of the input maps.
  • VALID: the output size shrinks by filter_size - 1, and the filters sweep only valid positions inside the input maps. No padding is needed.

In detail, assume:

  • Input feature maps: XH x XW x XC x N
  • Filters: FH x FW x XC x FC
  • Strides: strides[0] and strides[1] are row (h) and col (w) stride, respectively.

For the SAME convolution, the output height (YH) and width (YW) are computed as:

  • YH = ceil(float(XH) / float(strides[0]))
  • YW = ceil(float(XW) / float(strides[1]))

and the paddings are computed as:

  • pad_along_height = max((YH - 1) * strides[0] + FH - XH, 0)
  • pad_along_width = max((YW - 1) * strides[1] + FW - XW, 0)
  • pad_top = pad_along_height / 2
  • pad_bottom = pad_along_height - pad_top
  • pad_left = pad_along_width / 2
  • pad_right = pad_along_width - pad_left

For the VALID convolution, the output height (YH) and width (YW) are computed as:

  • YH = ceil(float(XH - FH + 1) / float(strides[0]))
  • YW = ceil(float(XW - FW + 1) / float(strides[1]))

and the paddings are always zero.

Return
The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Parameters
  • x: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
  • f: 2D convolution filters: H x W x Ci x Co (ColMaj), 4D tensor
  • b: The bias (1D: Co)
  • stride: the row and column strides
  • is_valid: ‘VALID’ convolution or ‘SAME’ convolution; the default is true (‘VALID’)

Expression dynet::maxpooling2d(const Expression &x, const std::vector<unsigned> &ksize, const std::vector<unsigned> &stride, bool is_valid = true)

maxpooling2d

2D maxpooling operator.

Return
The output feature maps (H x W x Co) x N, 3D tensor with an optional batch dimension
Parameters
  • x: The input feature maps: (H x W x Ci) x N (ColMaj), 3D tensor with an optional batch dimension
  • ksize: the height and width of the maxpooling2d window or kernel
  • stride: the row and column strides
  • is_valid: ‘VALID’ or ‘SAME’ convolution (see the comments for conv2d); the default is true (‘VALID’)
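
For example, a 2x2 VALID max-pool with stride 2 halves the spatial dimensions (sketch with arbitrary sizes): YH = YW = ceil((28 - 2 + 1) / 2) = 14.

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <iostream>
    #include <vector>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      dynet::ComputationGraph cg;
      std::vector<float> maps(28 * 28 * 8, 1.f);
      dynet::Expression x = dynet::input(cg, {28, 28, 8}, maps);
      // Kernel 2x2, stride 2, VALID: output is 14 x 14 x 8.
      dynet::Expression y = dynet::maxpooling2d(x, {2, 2}, {2, 2}, true);
      std::cout << y.dim() << std::endl;
      return 0;
    }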

Normalization Operations

This includes batch normalization and the like.

Expression dynet::layer_norm(const Expression &x, const Expression &g, const Expression &b)

Layer normalization.

Performs layer normalization:

\( \begin{split} \mu &= \frac 1 n \sum_{i=1}^n x_i\\ \sigma &= \sqrt{\frac 1 n \sum_{i=1}^n (x_i-\mu)^2}\\ y&=\frac {\boldsymbol{g}} \sigma \circ (\boldsymbol{x}-\mu) + \boldsymbol{b}\\ \end{split} \)

Reference: Ba et al., 2016

Return
An expression of the same dimension as x
Parameters
  • x: Input expression (possibly batched)
  • g: Gain (same dimension as x, no batch dimension)
  • b: Bias (same dimension as x, no batch dimension)
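
A sketch of layer normalization over a hidden vector, with the gain and bias as learned parameters (the size 128 is arbitrary):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <dynet/model.h>
    #include <vector>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      const unsigned H = 128;
      dynet::ParameterCollection pc;
      dynet::Parameter G = pc.add_parameters({H});  // gain g, same shape as x
      dynet::Parameter B = pc.add_parameters({H});  // bias b, same shape as x
      dynet::ComputationGraph cg;
      std::vector<float> hdata(H, 0.1f);
      dynet::Expression h = dynet::input(cg, {H}, hdata);
      dynet::Expression y = dynet::layer_norm(h, dynet::parameter(cg, G),
                                              dynet::parameter(cg, B));
      cg.forward(y);
      return 0;
    }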

Expression dynet::weight_norm(const Expression &w, const Expression &g)

Weight normalization.

Performs weight normalization:

\( \hat{w} = g\frac{w}{\Vert w\Vert} \)

Reference: Salimans and Kingma, 2016

Return
An expression of the same dimension as w
Parameters
  • w: Input expression (weight parameter)
  • g: Gain (scalar expression, usually also a parameter)

Device operations

These operations are device-related.

Expression dynet::to_device(const Expression &x, Device *device)

Copy tensor between devices.

Copy the tensor of x from its current device to device.

Return
An expression containing x’s tensor, placed on device
Parameters
  • x: Input expression
  • device: Device to place return tensor
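
As a sketch, assuming a CUDA-enabled build launched with a flag such as --dynet-devices CPU,GPU:0 (the device name “GPU:0” and its availability are assumptions about that configuration):

    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <dynet/devices.h>

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      // Look up the GPU device by name (assumes --dynet-devices CPU,GPU:0).
      dynet::Device* gpu =
          dynet::get_device_manager()->get_global_device("GPU:0");
      dynet::ComputationGraph cg;
      // The input lives on the default device; copy its tensor to the GPU.
      dynet::Expression x = dynet::input(cg, {4}, {1.f, 2.f, 3.f, 4.f});
      dynet::Expression x_gpu = dynet::to_device(x, gpu);
      cg.forward(x_gpu);
      return 0;
    }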