causally.scm.causal_mechanism.NeuralNetMechanism

class causally.scm.causal_mechanism.NeuralNetMechanism(weights_mean: float = 0.0, weights_std: float = 1.0, hidden_dim: int = 10, activation: Module = PReLU(num_parameters=1))

Nonlinear causal mechanism parametrized by a neural network.

The transformation is a fully connected neural network with a single hidden layer. A linear layer maps the parent values to hidden_dim units. It is followed by the activation function, LayerNorm, and a linear output layer that produces the effect.
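The forward pass described above can be sketched in NumPy as follows. This is an illustrative re-implementation, not the library's code. The function names are hypothetical, and LayerNorm is applied without its learnable scale and shift (equivalently, with their initial values of 1 and 0). PReLU uses one shared slope, matching num_parameters=1, set here to 0.25, PyTorch's default initial PReLU slope.

```python
import numpy as np


def prelu(h: np.ndarray, a: float) -> np.ndarray:
    # PReLU with one shared slope: identity for h >= 0, slope a for h < 0
    return np.where(h >= 0, h, a * h)


def layer_norm(h: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize each sample across its hidden units (biased variance, as in nn.LayerNorm)
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mean) / np.sqrt(var + eps)


def mechanism_forward(X, W1, b1, a, W2, b2):
    """Hypothetical helper: Linear -> PReLU -> LayerNorm -> Linear, (n, d) parents -> (n,) effects."""
    h = X @ W1 + b1               # hidden layer: (n, d) -> (n, hidden_dim)
    h = prelu(h, a)               # nonlinear activation
    h = layer_norm(h)             # per-sample normalization over hidden units
    return (h @ W2 + b2).ravel()  # linear output layer: (n, hidden_dim) -> (n,)


rng = np.random.default_rng(0)
n, d, hidden_dim = 5, 3, 10
X = rng.normal(size=(n, d))
W1, b1 = rng.normal(size=(d, hidden_dim)), rng.normal(size=hidden_dim)
W2, b2 = rng.normal(size=(hidden_dim, 1)), rng.normal(size=1)
effect = mechanism_forward(X, W1, b1, 0.25, W2, b2)
print(effect.shape)  # (5,)
```

Because LayerNorm sits between the activation and the output layer, the scale of the effect is governed mainly by the output layer's weights rather than by the magnitude of the parents.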

Parameters:
  • weights_mean (float, default 0.0) – Mean of the initialized network weights.

  • weights_std (float, default 1.0) – Standard deviation of the initialized network weights.

  • hidden_dim (int, default 10) – Number of neurons in the hidden layer.

  • activation (nn.Module, default nn.PReLU()) – The nonlinear activation function applied to the hidden layer.
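A sketch of how the constructor parameters might translate into initial weights. It assumes that every linear layer's weights and biases are drawn i.i.d. from a normal distribution with mean weights_mean and standard deviation weights_std. This page states neither the distribution nor whether biases are included, so treat both as assumptions. init_linear_params is a hypothetical helper.

```python
import numpy as np


def init_linear_params(n_parents, hidden_dim=10, weights_mean=0.0, weights_std=1.0, seed=None):
    """Hypothetical helper: Gaussian init of the two linear layers (assumed, not documented)."""
    rng = np.random.default_rng(seed)

    def draw(*shape):
        # Assumption: weights and biases share the same N(weights_mean, weights_std**2) init
        return rng.normal(loc=weights_mean, scale=weights_std, size=shape)

    return {
        "W1": draw(n_parents, hidden_dim),  # input -> hidden layer
        "b1": draw(hidden_dim),
        "W2": draw(hidden_dim, 1),          # hidden layer -> single effect output
        "b2": draw(1),
    }


params = init_linear_params(n_parents=3, hidden_dim=10, weights_std=2.0, seed=0)
print({k: v.shape for k, v in params.items()})
# {'W1': (3, 10), 'b1': (10,), 'W2': (10, 1), 'b2': (1,)}
```

Setting weights_std=0.0 makes every entry equal weights_mean, which is a quick way to check the layer shapes and wiring.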

Attributes

model

The nn.Module instance that implements the neural network described above.

Methods

predict(X)

Generate the effect given the observations of the parent nodes.
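The sketch below mimics the documented interface with a self-contained NumPy class, ToyNeuralNetMechanism. It is hypothetical and not part of causally. It assumes the following behavior, none of which this page states:

- The network is built on the first predict call, once the number of parents is known, and then reused. Repeated calls therefore apply the same fixed mechanism.
- X has shape (n_samples, n_parents).
- A 1-D array is treated as a single parent.

```python
import numpy as np


class ToyNeuralNetMechanism:
    """Hypothetical NumPy stand-in for the documented NeuralNetMechanism interface."""

    def __init__(self, weights_mean=0.0, weights_std=1.0, hidden_dim=10, prelu_slope=0.25, seed=None):
        self.weights_mean = weights_mean
        self.weights_std = weights_std
        self.hidden_dim = hidden_dim
        self.prelu_slope = prelu_slope
        self._rng = np.random.default_rng(seed)
        self._params = None  # built lazily, once the number of parents is known

    def _build(self, n_parents):
        def draw(*shape):
            return self._rng.normal(self.weights_mean, self.weights_std, size=shape)
        self._params = (draw(n_parents, self.hidden_dim), draw(self.hidden_dim),
                        draw(self.hidden_dim, 1), draw(1))

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        if X.ndim == 1:
            X = X.reshape(-1, 1)  # a 1-D array is treated as a single parent
        if self._params is None:
            self._build(X.shape[1])
        W1, b1, W2, b2 = self._params
        h = X @ W1 + b1                                # hidden layer
        h = np.where(h >= 0, h, self.prelu_slope * h)  # PReLU
        h = (h - h.mean(axis=1, keepdims=True)) / np.sqrt(h.var(axis=1, keepdims=True) + 1e-5)  # LayerNorm
        return (h @ W2 + b2).ravel()                   # linear output layer


mech = ToyNeuralNetMechanism(hidden_dim=10, seed=0)
parents = np.random.default_rng(1).normal(size=(100, 2))
effect = mech.predict(parents)
print(effect.shape)  # (100,)
```

With the real class, the corresponding call follows the signature documented above, for example NeuralNetMechanism(hidden_dim=10).predict(X).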