Builders

Builders combine various operations to implement more complicated structures such as recurrent networks, LSTMs, or hierarchical softmax

RNN Builders

struct dynet::CoupledLSTMBuilder
#include <lstm.h>

CoupledLSTMBuilder creates an LSTM unit with coupled input and forget gates as well as peephole connections.

More specifically, here are the equations for the dynamics of this cell:

\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+W_{ic}c_{t-1}+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+W_{oc}c_{t}+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)

Inherits from dynet::RNNBuilder

Public Functions

dynet::CoupledLSTMBuilder::CoupledLSTMBuilder()

Default constructor.

dynet::CoupledLSTMBuilder::CoupledLSTMBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model)

Constructor for the LSTMBuilder.

Parameters
  • layers: Number of layers
  • input_dim: Dimension of the input \(x_t\)
  • hidden_dim: Dimension of the hidden states \(h_t\) and \(c_t\)
  • model: ParameterCollection holding the parameters
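
As a rough usage sketch (the dimensions, input values, and sequence length below are arbitrary placeholders), a CoupledLSTMBuilder can be constructed and unrolled over a sequence as follows:

    #include <vector>
    #include <dynet/dynet.h>
    #include <dynet/expr.h>
    #include <dynet/lstm.h>

    using namespace dynet;

    int main(int argc, char** argv) {
      dynet::initialize(argc, argv);
      ParameterCollection model;
      // 2 layers, 32-dimensional input, 64-dimensional hidden and cell states
      CoupledLSTMBuilder lstm(2, 32, 64, model);

      ComputationGraph cg;
      lstm.new_graph(cg);          // attach the builder to this graph
      lstm.start_new_sequence();   // h_0 and c_0 default to zero vectors

      std::vector<float> x_val(32, 0.5f);  // dummy input vector
      Expression h;
      for (unsigned t = 0; t < 10; ++t) {
        Expression x = input(cg, {32}, x_val);
        h = lstm.add_input(x);     // hidden state of the deepest layer at time t
      }
      // h now holds the last hidden state; lstm.final_s() would return
      // the cell states followed by the hidden states of every layer.
      return 0;
    }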

unsigned dynet::CoupledLSTMBuilder::num_h0_components() const

Number of components in h_0

For LSTMBuilder, this corresponds to 2 * layers because it includes the initial cell state \(c_0\)

Return
2 * layers

std::vector<Expression> dynet::CoupledLSTMBuilder::get_s(RNNPointer i) const

Get the final state of the hidden layer.

For LSTMBuilder, this consists of a vector of the memory cell values for each layer (l1, l2, l3), followed by the hidden state values

Return
{c_{l1}, c_{l2}, …, h_{l1}, h_{l2}, …}

void dynet::CoupledLSTMBuilder::set_dropout(float d)

Set the dropout rates to a unique value.

This has the same effect as set_dropout(d,d_h,d_c) except that all the dropout rates are set to the same value.

Parameters
  • d: Dropout rate to be applied on all of \(x,h,c\)

void dynet::CoupledLSTMBuilder::set_dropout(float d, float d_h, float d_c)

Set the dropout rates.

The dropout implemented here is an adaptation of the variational dropout with tied weights introduced in Gal, 2016. More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\), \(\mathbf{z_c}\sim \mathrm{Bernoulli}(1-d_c)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to:

\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{ih}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+W_{ic}(\frac 1 {1-d_c} {\mathbf{z_c}} \circ c_{t-1})+b_i)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{ch}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ (1-i_t) + \tilde{c_t}\circ i_t\\ & = c_{t-1} + (\tilde{c_t}-c_{t-1})\circ i_t\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x} {\mathbf{z_x}} \circ x_t)+W_{oh}(\frac 1 {1-d_h} {\mathbf{z_h}} \circ h_{t-1})+W_{oc}(\frac 1 {1-d_c} {\mathbf{z_c}} \circ c_{t})+b_o)\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)

For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation

Parameters
  • d: Dropout rate \(d_x\) for the input \(x_t\)
  • d_h: Dropout rate \(d_h\) for the output \(h_t\)
  • d_c: Dropout rate \(d_c\) for the cell \(c_t\)

void dynet::CoupledLSTMBuilder::disable_dropout()

Set all dropout rates to 0.

This is equivalent to set_dropout(0) or set_dropout(0,0,0)

void dynet::CoupledLSTMBuilder::set_dropout_masks(unsigned batch_size = 1)

Set dropout masks at the beginning of a sequence for a specific batch size.

If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element

Parameters
  • batch_size: Batch size
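
As a brief sketch of how these dropout functions interact (the rates and the batch size of 16 are placeholders), dropout is typically configured once and fresh masks are sampled at the start of each sequence:

    ParameterCollection model;
    CoupledLSTMBuilder lstm(1, 32, 64, model);
    lstm.set_dropout(0.5f, 0.3f, 0.1f);   // d_x = 0.5, d_h = 0.3, d_c = 0.1

    // For each training sequence / minibatch:
    ComputationGraph cg;
    lstm.new_graph(cg);
    lstm.start_new_sequence();
    lstm.set_dropout_masks(16);           // one mask per batch element (batch size 16 here)
    // ... add_input(...) calls on batched Expressions ...

    // At test time, turn dropout off entirely:
    lstm.disable_dropout();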

ParameterCollection &dynet::CoupledLSTMBuilder::get_parameter_collection()

Get parameters in LSTMBuilder.

struct dynet::VanillaLSTMBuilder
#include <lstm.h>

VanillaLSTM allows the creation of a “standard” LSTM, i.e. with decoupled input and forget gates and no peephole connections.

This cell runs according to the following dynamics:

\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)

Inherits from dynet::RNNBuilder

Public Functions

dynet::VanillaLSTMBuilder::VanillaLSTMBuilder()

Default Constructor.

dynet::VanillaLSTMBuilder::VanillaLSTMBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model, bool ln_lstm = false, float forget_bias = 1.f)

Constructor for the VanillaLSTMBuilder.

Parameters
  • layers: Number of layers
  • input_dim: Dimension of the input \(x_t\)
  • hidden_dim: Dimension of the hidden states \(h_t\) and \(c_t\)
  • model: ParameterCollection holding the parameters
  • ln_lstm: Whether to use layer normalization
  • forget_bias: Value (float) to use as bias for the forget gate (default: 1.0)
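
A minimal construction sketch (the dimensions are placeholders, and `inputs` is assumed to be a std::vector<Expression> built on the current graph):

    ParameterCollection model;
    // 1 layer, 128-dim input, 256-dim hidden state, with layer normalization
    // and the default forget-gate bias of 1.0
    VanillaLSTMBuilder lstm(1, 128, 256, model, /*ln_lstm=*/true);

    ComputationGraph cg;
    lstm.new_graph(cg);
    lstm.start_new_sequence();
    for (const auto& x_t : inputs)
      lstm.add_input(x_t);
    Expression last_h = lstm.back();   // most recent output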

void dynet::VanillaLSTMBuilder::set_dropout(float d)

Set the dropout rates to a unique value.

This has the same effect as set_dropout(d,d_h) except that all the dropout rates are set to the same value.

Parameters
  • d: Dropout rate to be applied on all of \(x,h\)

void dynet::VanillaLSTMBuilder::set_dropout(float d, float d_r)

Set the dropout rates.

The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016. More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to:

\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)

For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation

Parameters
  • d: Dropout rate \(d_x\) for the input \(x_t\)
  • d_r: Dropout rate \(d_h\) for the recurrent output \(h_t\)

void dynet::VanillaLSTMBuilder::disable_dropout()

Set all dropout rates to 0.

This is equivalent to set_dropout(0) or set_dropout(0,0)

void dynet::VanillaLSTMBuilder::set_dropout_masks(unsigned batch_size = 1)

Set dropout masks at the beginning of a sequence for a specific batch size.

If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element

Parameters
  • batch_size: Batch size

ParameterCollection &dynet::VanillaLSTMBuilder::get_parameter_collection()

Get parameters in VanillaLSTMBuilder.

Return
The ParameterCollection holding the parameters of the builder

struct dynet::CompactVanillaLSTMBuilder
#include <lstm.h>

CompactVanillaLSTM allows the creation of a “standard” LSTM, i.e. with decoupled input and forget gates and no peephole connections.

This cell runs according to the following dynamics:

\( \begin{split} i_t & =\sigma(W_{ix}x_t+W_{ih}h_{t-1}+b_i)\\ f_t & = \sigma(W_{fx}x_t+W_{fh}h_{t-1}+b_f+1)\\ o_t & = \sigma(W_{ox}x_t+W_{oh}h_{t-1}+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}x_t+W_{ch}h_{t-1}+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)

Inherits from dynet::RNNBuilder

Public Functions

dynet::CompactVanillaLSTMBuilder::CompactVanillaLSTMBuilder()

Default Constructor.

dynet::CompactVanillaLSTMBuilder::CompactVanillaLSTMBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model)

Constructor for the CompactVanillaLSTMBuilder.

Parameters
  • layers: Number of layers
  • input_dim: Dimension of the input \(x_t\)
  • hidden_dim: Dimension of the hidden states \(h_t\) and \(c_t\)
  • model: ParameterCollection holding the parameters

void dynet::CompactVanillaLSTMBuilder::set_dropout(float d)

Set the dropout rates to a unique value.

This has the same effect as set_dropout(d,d_h) except that all the dropout rates are set to the same value.

Parameters
  • d: Dropout rate to be applied on all of \(x,h\)

void dynet::CompactVanillaLSTMBuilder::set_dropout(float d, float d_r)

Set the dropout rates.

The dropout implemented here is the variational dropout with tied weights introduced in Gal, 2016. More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\), \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to:

\( \begin{split} i_t & =\sigma(W_{ix}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ih}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_i)\\ f_t & = \sigma(W_{fx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{fh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_f)\\ o_t & = \sigma(W_{ox}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{oh}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_o)\\ \tilde{c_t} & = \tanh(W_{cx}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{ch}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b_c)\\ c_t & = c_{t-1}\circ f_t + \tilde{c_t}\circ i_t\\ h_t & = \tanh(c_t)\circ o_t\\ \end{split} \)

For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation

Parameters
  • d: Dropout rate \(d_x\) for the input \(x_t\)
  • d_r: Dropout rate \(d_h\) for the recurrent output \(h_t\)

void dynet::CompactVanillaLSTMBuilder::disable_dropout()

Set all dropout rates to 0.

This is equivalent to set_dropout(0) or set_dropout(0,0)

void dynet::CompactVanillaLSTMBuilder::set_dropout_masks(unsigned batch_size = 1)

Set dropout masks at the beginning of a sequence for a specific batch size.

If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element

Parameters
  • batch_size: Batch size

void dynet::CompactVanillaLSTMBuilder::set_weightnoise(float std)

Set the Gaussian weight noise.

Parameters
  • std: Standard deviation of the weight noise

ParameterCollection &dynet::CompactVanillaLSTMBuilder::get_parameter_collection()

Get parameters in CompactVanillaLSTMBuilder.

Return
The ParameterCollection holding the parameters of the builder

struct dynet::RNNBuilder
#include <rnn.h>

Interface for constructing an RNN, LSTM, GRU, etc.

Subclassed by dynet::CompactVanillaLSTMBuilder, dynet::CoupledLSTMBuilder, dynet::DeepLSTMBuilder, dynet::FastLSTMBuilder, dynet::GRUBuilder, dynet::SimpleRNNBuilder, dynet::TreeLSTMBuilder, dynet::VanillaLSTMBuilder

Public Functions

dynet::RNNBuilder::RNNBuilder()

Default constructor.

RNNPointer dynet::RNNBuilder::state() const

Get pointer to the current state.

Return
Pointer to the current state

void dynet::RNNBuilder::new_graph(ComputationGraph &cg, bool update = true)

Initialize with new computation graph.

call this to reset the builder when you are working with a newly created ComputationGraph object

Parameters
  • cg: Computation graph
  • update: Update internal parameters while training

void dynet::RNNBuilder::start_new_sequence(const std::vector<Expression> &h_0 = {})

Reset for new sequence.

call this before add_input and after new_graph, when starting a new sequence on the same hypergraph.

Parameters
  • h_0: h_0 is used to initialize hidden layers at timestep 0 to given values
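
As an illustrative sketch (rnn, sent1, and sent2 are assumed to be an existing RNNBuilder and two vectors of input Expressions), several sequences can be run on the same graph, optionally carrying state over via final_s:

    ComputationGraph cg;
    rnn.new_graph(cg);                    // once per ComputationGraph

    // First sequence, zero-initialized state
    rnn.start_new_sequence();
    for (const auto& x_t : sent1) rnn.add_input(x_t);

    // Second sequence on the same graph, initialized from the first one.
    // For an LSTM, h_0 must contain num_h0_components() expressions
    // (internal cell states followed by hidden states, one per layer).
    std::vector<Expression> init = rnn.final_s();
    rnn.start_new_sequence(init);
    for (const auto& x_t : sent2) rnn.add_input(x_t);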

Expression dynet::RNNBuilder::set_h(const RNNPointer &prev, const std::vector<Expression> &h_new = {})

Explicitly set the output state of a node.

Return
The hidden representation of the deepest layer
Parameters
  • prev: Pointer to the previous state
  • h_new: The new hidden state

Expression dynet::RNNBuilder::set_s(const RNNPointer &prev, const std::vector<Expression> &s_new = {})

Set the internal state of a node (for LSTMs/GRUs)

For RNNs without internal states (SimpleRNN, GRU…), this has the same behaviour as set_h

Return
The hidden representation of the deepest layer
Parameters
  • prev: Pointer to the previous state
  • s_new: The new state. Can be {new_c[0],...,new_c[n]} or {new_c[0],...,new_c[n], new_h[0],...,new_h[n]}

Expression dynet::RNNBuilder::add_input(const Expression &x)

Add another timestep by reading in the variable x.

Return
The hidden representation of the deepest layer
Parameters
  • x: Input variable

Expression dynet::RNNBuilder::add_input(const RNNPointer &prev, const Expression &x)

Add another timestep, with arbitrary recurrent connection.

This allows you to define a recurrent connection to prev rather than to head[cur]. This can be used to construct trees, implement beam search, etc.

Return
The hidden representation of the deepest layer
Parameters
  • prev: Pointer to the previous state
  • x: Input variable
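
A small sketch of branching from an earlier state (x1, x2a, and x2b are assumed to be input Expressions), e.g. for beam search or tree-shaped unrolling:

    rnn.start_new_sequence();
    rnn.add_input(x1);                                  // regular step
    RNNPointer branch_point = rnn.state();              // remember this state
    Expression h_a = rnn.add_input(x2a);                // continues after x1
    Expression h_b = rnn.add_input(branch_point, x2b);  // also continues after x1
    // h_a and h_b share the same history up to x1 but diverge afterwards.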

void dynet::RNNBuilder::rewind_one_step()

Rewind the last timestep.

  • this DOES NOT remove the variables from the computation graph, it just means the next time step will see a different previous state. You can rewind as many times as you want.

RNNPointer dynet::RNNBuilder::get_head(const RNNPointer &p)

Return the RNN state that is the parent of p

  • This can be used in implementing complex structures such as trees, etc.

virtual void dynet::RNNBuilder::set_dropout(float d)

Set Dropout.

Parameters
  • d: Dropout rate

virtual void dynet::RNNBuilder::disable_dropout()

Disable Dropout.

In general, you should disable dropout at test time

virtual Expression dynet::RNNBuilder::back() const = 0

Returns node (index) of most recent output.

Return
Node (index) of most recent output

virtual std::vector<Expression> dynet::RNNBuilder::final_h() const = 0

Access the final output of each hidden layer.

Return
Final output of each hidden layer

virtual std::vector<Expression> dynet::RNNBuilder::get_h(RNNPointer i) const = 0

Access the output of any hidden layer.

Return
Output of each hidden layer at the given step
Parameters
  • i: Pointer to the step whose output you want to access

virtual std::vector<Expression> dynet::RNNBuilder::final_s() const = 0

Access the final state of each hidden layer.

This returns the state of each hidden layer, in a format that can be used in start_new_sequence (i.e. including any internal cell for LSTMs and the likes)

Return
vector containing, if it exists, the list of final internal states, followed by the list of final outputs for each layer

virtual std::vector<Expression> dynet::RNNBuilder::get_s(RNNPointer i) const = 0

Access the state of any hidden layer.

See final_s for details

Return
Internal state of each hidden layer at the given step
Parameters
  • i: Pointer to the step whose state you want to access

virtual unsigned dynet::RNNBuilder::num_h0_components() const = 0

Number of components in h_0

Return
Number of components in h_0

virtual void dynet::RNNBuilder::copy(const RNNBuilder &params) = 0

Copy the parameters of another builder.

Parameters
  • params: RNNBuilder you want to copy parameters from.

struct dynet::SimpleRNNBuilder
#include <rnn.h>

This provides a builder for the simplest RNN with tanh nonlinearity.

The equation for this RNN is: \(h_t=\tanh(W_x x_t + W_h h_{t-1} + b)\)

Inherits from dynet::RNNBuilder

Public Functions

dynet::SimpleRNNBuilder::SimpleRNNBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, ParameterCollection &model, bool support_lags = false)

Builds a simple RNN.

Parameters
  • layers: Number of layers
  • input_dim: Dimension of the input
  • hidden_dim: Hidden layer (and output) size
  • model: ParameterCollection holding the parameters
  • support_lags: Allow for auxiliary output?

Expression dynet::SimpleRNNBuilder::add_auxiliary_input(const Expression &x, const Expression &aux)

Add auxiliary output.

Returns \(h_t=\tanh(W_x x_t + W_h h_{t-1} + W_y y + b)\) where \(y\) is the auxiliary expression aux

Return
The hidden representation of the deepest layer
Parameters
  • x: Input expression
  • aux: Auxiliary output expression
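
A brief sketch of feeding both a regular and an auxiliary expression at each step (xs and aux are hypothetical vectors of Expressions; the constructor flag support_lags is assumed here to enable the auxiliary path):

    ParameterCollection model;
    SimpleRNNBuilder rnn(1, 64, 128, model, /*support_lags=*/true);

    ComputationGraph cg;
    rnn.new_graph(cg);
    rnn.start_new_sequence();
    for (unsigned t = 0; t < xs.size(); ++t)
      rnn.add_auxiliary_input(xs[t], aux[t]);  // adds the extra W_y * aux[t] term per step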

void dynet::SimpleRNNBuilder::set_dropout(float d)

Set the dropout rates to a unique value.

This has the same effect as set_dropout(d,d_h) except that all the dropout rates are set to the same value.

Parameters
  • d: Dropout rate to be applied on all of \(x,h\)

void dynet::SimpleRNNBuilder::set_dropout(float d, float d_h)

Set the dropout rates.

The dropout implemented here is the variational dropout introduced in Gal, 2016. More specifically, dropout masks \(\mathbf{z_x}\sim \mathrm{Bernoulli}(1-d_x)\) and \(\mathbf{z_h}\sim \mathrm{Bernoulli}(1-d_h)\) are sampled at the start of each sequence. The dynamics of the cell are then modified to:

\( h_t =\tanh(W_{x}(\frac 1 {1-d_x}\mathbf{z_x} \circ x_t)+W_{h}(\frac 1 {1-d_h}\mathbf{z_h} \circ h_{t-1})+b) \)

For more detail as to why scaling is applied, see the “Unorthodox” section of the documentation

Parameters
  • d: Dropout rate \(d_x\) for the input \(x_t\)
  • d_h: Dropout rate \(d_h\) for the hidden state \(h_t\)

void dynet::SimpleRNNBuilder::set_dropout_masks(unsigned batch_size = 1)

Set dropout masks at the beginning of a sequence for a specific batch size.

If this function is not called on batched input, the same mask will be applied across all batch elements. Use this to apply different masks to each batch element

Parameters
  • batch_size: Batch size

struct dynet::TreeLSTMBuilder
#include <treelstm.h>

TreeLSTMBuilder is the base class for tree structured lstm builders.

Inherits from dynet::RNNBuilder

Subclassed by dynet::BidirectionalTreeLSTMBuilder, dynet::NaryTreeLSTMBuilder, dynet::UnidirectionalTreeLSTMBuilder

Public Functions

virtual Expression dynet::TreeLSTMBuilder::add_input(int id, std::vector<int> children, const Expression &x) = 0

Add input with given children at position id.

If you did not call set_num_elements before, each successive id must be the previous id plus one and the children must all be smaller than id. If you used set_num_elements, id must be smaller than the number of elements and the children must have already been provided.

Parameters
  • id: index where x should be stored
  • children: indices of the children for x

virtual void dynet::TreeLSTMBuilder::set_num_elements(int num) = 0

Set the number of nodes in your tree in advance.

By default, input to a TreeLSTMBuilder needs to be in ascending order, i.e. when sequentializing the nodes, all leaves have to be first. If you know the number of elements beforehand, you can call this method to then place your nodes at arbitrary indices, e.g. because you already have a sequentialization that does not conform to the leaves-first requirement.

Parameters
  • num: desired size
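
A rough sketch of feeding a small tree in leaves-first order (tree_lstm is assumed to be an already constructed TreeLSTMBuilder subclass attached to the graph cg, and xs holds one Expression per node):

    tree_lstm.new_graph(cg);
    tree_lstm.start_new_sequence();

    // Without set_num_elements, ids must be consecutive and children must
    // already have been added:
    tree_lstm.add_input(0, {},     xs[0]);            // leaf
    tree_lstm.add_input(1, {},     xs[1]);            // leaf
    tree_lstm.add_input(2, {0, 1}, xs[2]);            // parent of nodes 0 and 1
    Expression root = tree_lstm.add_input(3, {2}, xs[3]);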

struct dynet::NaryTreeLSTMBuilder
#include <treelstm.h>

Builds N-ary trees with a fixed upper bound on the number of children. See “Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks” by Tai, Socher, and Manning (2015), section 3.2, for details on this model. http://arxiv.org/pdf/1503.00075v3.pdf.

Inherits from dynet::TreeLSTMBuilder

struct dynet::UnidirectionalTreeLSTMBuilder
#include <treelstm.h>

Builds a tree-LSTM which is recursively defined by a unidirectional LSTM over the node and its children representations.

Inherits from dynet::TreeLSTMBuilder

struct dynet::BidirectionalTreeLSTMBuilder
#include <treelstm.h>

Builds a tree-LSTM which is recursively defined by a Bidirectional LSTM over the node and its children representations.

Inherits from dynet::TreeLSTMBuilder

Softmax Builders

class dynet::SoftmaxBuilder
#include <cfsm-builder.h>

Interface for building softmax layers.

A softmax layer returns a probability distribution over \(C\) classes given a vector \(h\in\mathbb R^d\), with

\(p(c=i\mid h)\propto \exp(W_i^Th + b_i)\ \forall i\in\{1\ldots C\}\)

where \(W\in \mathbb R^{C\times d}, b \in \mathbb R^C\)

Subclassed by dynet::ClassFactoredSoftmaxBuilder, dynet::HierarchicalSoftmaxBuilder, dynet::StandardSoftmaxBuilder

Public Functions

virtual void dynet::SoftmaxBuilder::new_graph(ComputationGraph &cg, bool update = true) = 0

This initializes the parameters in the computation graph.

Call this once per ComputationGraph before any computation with the softmax

Parameters
  • cg: Computation graph
  • update: Whether to update the parameters

virtual Expression dynet::SoftmaxBuilder::neg_log_softmax(const Expression &rep, unsigned classidx) = 0

Negative log probability of a class.

Given class \(c\) and vector \(h\), this returns \(-\log(p(c \mid h))\)

Return
\(-\log(p(\texttt{class} \mid \texttt{rep}))\)
Parameters
  • rep: vector expression
  • classidx: Class index

virtual Expression dynet::SoftmaxBuilder::neg_log_softmax(const Expression &rep, const std::vector<unsigned> &classidxs) = 0

Batched version of the former.

Returns a batched scalar

Return
\(-\log(p(\texttt{class}_b \mid \texttt{rep}_b))\) for each batch element \(b\)
Parameters
  • rep: Vector expression (batched)
  • classidxs: List of class indices, one per batch element

virtual unsigned dynet::SoftmaxBuilder::sample(const Expression &rep) = 0

Sample from the softmax distribution.

Return
Sampled class
Parameters
  • rep: Vector expression parametrizing the distribution

virtual Expression dynet::SoftmaxBuilder::full_log_distribution(const Expression &rep) = 0

Returns an Expression representing a vector the size of the number of classes.

The ith dimension gives \(\log p(c_i | \texttt{rep})\). This function may be SLOW. Avoid if possible.

Return
Expression of the distribution
Parameters
  • rep: Vector expression parametrizing the distribution

virtual Expression dynet::SoftmaxBuilder::full_logits(const Expression &rep) = 0

Returns the logits (before application of the softmax)

The ith dimension gives \(W_i^Th + b_i\)

Return
Expression for the logits
Parameters
  • rep: Vector expression parametrizing the distribution

virtual ParameterCollection &dynet::SoftmaxBuilder::get_parameter_collection() = 0

Returns the ParameterCollection containing the softmax parameters.

Return
ParameterCollection

class dynet::StandardSoftmaxBuilder
#include <cfsm-builder.h>

This class implements the standard Softmax.

Inherits from dynet::SoftmaxBuilder

Public Functions

dynet::StandardSoftmaxBuilder::StandardSoftmaxBuilder(unsigned rep_dim, unsigned num_classes, ParameterCollection &pc, bool bias = true)

Constructs a StandardSoftmaxBuilder.

This creates the parameters given the dimensions

Parameters
  • rep_dim: Dimension of the input vectors
  • num_classes: Number of classes
  • pc: Parameter collection
  • bias: Whether to use a bias vector or not
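
A short usage sketch (h is assumed to be a 128-dimensional Expression, e.g. an RNN output, and target is a gold class index; the dimensions are placeholders):

    ParameterCollection model;
    StandardSoftmaxBuilder softmax(128, 10000, model);  // 128-dim input, 10000 classes

    ComputationGraph cg;
    softmax.new_graph(cg);
    Expression loss = softmax.neg_log_softmax(h, target);   // training loss for one example
    unsigned pred = softmax.sample(h);                       // draw a class from p(c | h)
    Expression logp = softmax.full_log_distribution(h);     // full log-distribution (may be slow)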

dynet::StandardSoftmaxBuilder::StandardSoftmaxBuilder(Parameter &p_w, Parameter &p_b)

Builds a softmax layer with pre-existing parameters.

Parameters
  • p_w: Weight matrix
  • p_b: Bias vector

dynet::StandardSoftmaxBuilder::StandardSoftmaxBuilder(Parameter &p_w)

Builds a softmax layer with pre-existing parameters (no bias)

Parameters
  • p_w: Weight matrix

class dynet::ClassFactoredSoftmaxBuilder
#include <cfsm-builder.h>

Class factored softmax.

The vocabulary is partitioned into classes, i.e. \(p(i\mid h)=p(i\mid h, c)\, p(c\mid h)\) where \(c\) is a class and \(i\) a word within that class

Inherits from dynet::SoftmaxBuilder

Public Functions

dynet::ClassFactoredSoftmaxBuilder::ClassFactoredSoftmaxBuilder(unsigned rep_dim, const std::string &cluster_file, Dict &word_dict, ParameterCollection &pc, bool bias = true)

Constructor from file.

This constructs the CFSM from a file whose lines have the following format

CLASSID   word    [freq]

with one word per line; the frequency column is optional.

Parameters
  • rep_dim: Dimension of the input vector
  • cluster_file: File containing classes
  • word_dict: Dictionary for words (maps words to index)
  • pc: ParameterCollection
  • bias: Whether to use a bias vector or not
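
As a hedged sketch (the file name "clusters.txt", the class assignments, and the 256-dimensional rep_dim are placeholders), a cluster file could look like:

    0   the
    0   a
    1   dog
    1   cat

and the builder would then be constructed and used as follows (word_dict is assumed to already contain the vocabulary, and h is a 256-dimensional Expression; requires dynet/cfsm-builder.h and dynet/dict.h):

    Dict word_dict;
    // ... fill word_dict with the vocabulary ...
    ParameterCollection model;
    ClassFactoredSoftmaxBuilder cfsm(256, "clusters.txt", word_dict, model);

    ComputationGraph cg;
    cfsm.new_graph(cg);
    Expression loss = cfsm.neg_log_softmax(h, word_dict.convert("dog"));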

Expression dynet::ClassFactoredSoftmaxBuilder::class_log_distribution(const Expression &rep)

Get log distribution over classes.

Return
Vector of \(\log(p(c\mid \texttt{rep}))\)
Parameters
  • rep: Input vector

Expression dynet::ClassFactoredSoftmaxBuilder::class_logits(const Expression &rep)

Get logits of classes.

Return
Logits
Parameters
  • rep: Input vector

Expression dynet::ClassFactoredSoftmaxBuilder::subclass_log_distribution(const Expression &rep, unsigned clusteridx)

Get log distribution over subclasses of class.

Return
Vector of \(\log(p(i\mid c, \texttt{rep}))\)
Parameters
  • rep: Input vector
  • clusteridx: Class index

Expression dynet::ClassFactoredSoftmaxBuilder::subclass_logits(const Expression &rep, unsigned clusteridx)

Logits over subclasses of class.

Return
Logits
Parameters
  • rep: Input vector
  • clusteridx: Class index