causally.scm.causal_mechanism.NeuralNetMechanism
- class causally.scm.causal_mechanism.NeuralNetMechanism(weights_mean: float = 0.0, weights_std: float = 1.0, hidden_dim: int = 10, activation: Module = PReLU(num_parameters=1))
Nonlinear causal mechanism parametrized by a neural network.
The transformation is parametrized by a simple feed-forward neural network: one hidden linear layer, followed by an activation function, a LayerNorm, and a linear output layer.
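The layer order described above can be sketched as a plain NumPy forward pass. This is an illustration, not the library's code: causally builds the network as a PyTorch `nn.Module`. The helper names (`prelu`, `layer_norm`, `forward`) are hypothetical. The fixed PReLU slope of 0.25 mirrors PyTorch's default initial value for `nn.PReLU`. LayerNorm is shown without its learnable scale and shift, which PyTorch initializes to 1 and 0.

```python
import numpy as np


def prelu(z: np.ndarray, slope: float = 0.25) -> np.ndarray:
    """PReLU with one shared slope (0.25 is PyTorch's default initial value)."""
    return np.where(z >= 0.0, z, slope * z)


def layer_norm(h: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize each row to zero mean and unit variance (affine part omitted)."""
    mean = h.mean(axis=1, keepdims=True)
    var = h.var(axis=1, keepdims=True)  # biased variance, as in nn.LayerNorm
    return (h - mean) / np.sqrt(var + eps)


def forward(X, W1, b1, W2, b2):
    """Hypothetical sketch: Linear(d -> hidden) -> PReLU -> LayerNorm -> Linear(hidden -> 1)."""
    h = X @ W1 + b1                # hidden layer, shape (n, hidden_dim)
    h = prelu(h)                   # nonlinear activation
    h = layer_norm(h)              # LayerNorm over the hidden units
    return (h @ W2 + b2).ravel()   # linear output, one value per sample


rng = np.random.default_rng(0)
n, d, hidden_dim = 5, 3, 10
X = rng.normal(size=(n, d))
W1, b1 = rng.normal(size=(d, hidden_dim)), rng.normal(size=hidden_dim)
W2, b2 = rng.normal(size=(hidden_dim, 1)), rng.normal(size=1)
y = forward(X, W1, b1, W2, b2)
print(y.shape)  # (5,)
```

Placing LayerNorm after the activation keeps the hidden representation on a comparable scale whatever the input magnitudes, so the output layer sees standardized features.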
- Parameters:
weights_mean (float, default 0) – Average value of the initialized weights.
weights_std (float, default 1) – Standard deviation of the initialized weights.
hidden_dim (int, default 10) – Number of neurons in the hidden layer.
activation (nn.Module, default nn.PReLU) – The nonlinear activation function.
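A minimal sketch of how `weights_mean`, `weights_std`, and `hidden_dim` could shape the network's weights. It assumes every weight is drawn i.i.d. from a normal distribution with those parameters; the parameter descriptions only state the mean and standard deviation, so the distribution family is an assumption. The helper `init_weights` is hypothetical, and biases are left out for brevity.

```python
import numpy as np


def init_weights(n_features, hidden_dim=10, weights_mean=0.0, weights_std=1.0, seed=None):
    """Hypothetical helper: draw both weight matrices from Normal(weights_mean, weights_std).

    Assumes a normal distribution; biases are omitted for brevity.
    """
    rng = np.random.default_rng(seed)
    W1 = rng.normal(weights_mean, weights_std, size=(n_features, hidden_dim))  # input -> hidden
    W2 = rng.normal(weights_mean, weights_std, size=(hidden_dim, 1))           # hidden -> output
    return W1, W2


W1, W2 = init_weights(n_features=3, hidden_dim=10, seed=0)
print(W1.shape, W2.shape)  # (3, 10) (10, 1)
```

Larger `weights_std` values produce more strongly varying random functions, while `hidden_dim` controls how many random features the output layer combines.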
Attributes
nn.Module instance of the neural network architecture.
Methods
predict(X) – Generate the effect given the observations of the parent nodes.
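The contract of `predict(X)` (parent observations in, one effect value per sample out) can be illustrated with a hypothetical NumPy stand-in, `ToyNeuralNetMechanism`. It is not the library class. Three behaviors are assumptions made for the sketch: a 1-D input is treated as a single parent column; the weights are built lazily on the first call, once the number of parents is known; and they stay fixed afterwards, so repeated calls give the same effect.

```python
import numpy as np


class ToyNeuralNetMechanism:
    """Hypothetical NumPy stand-in for NeuralNetMechanism (not the library code)."""

    def __init__(self, weights_mean=0.0, weights_std=1.0, hidden_dim=10, seed=0):
        self.weights_mean = weights_mean
        self.weights_std = weights_std
        self.hidden_dim = hidden_dim
        self.rng = np.random.default_rng(seed)
        self._weights = None  # assumption: built lazily on the first predict call

    def _draw(self, *shape):
        return self.rng.normal(self.weights_mean, self.weights_std, size=shape)

    def predict(self, X):
        """Generate the effect given the observations of the parent nodes."""
        X = np.asarray(X, dtype=float)
        if X.ndim == 1:                  # a single parent: treat it as one column
            X = X.reshape(-1, 1)
        if self._weights is None:        # input width is only known now
            self._weights = (self._draw(X.shape[1], self.hidden_dim),
                             self._draw(self.hidden_dim, 1))
        W1, W2 = self._weights
        h = X @ W1                                   # hidden layer
        h = np.where(h >= 0, h, 0.25 * h)            # PReLU with slope 0.25
        mu = h.mean(axis=1, keepdims=True)
        h = (h - mu) / np.sqrt(h.var(axis=1, keepdims=True) + 1e-5)  # LayerNorm
        return (h @ W2).ravel()                      # one effect value per sample


mech = ToyNeuralNetMechanism(hidden_dim=10, seed=0)
parents = np.random.default_rng(1).normal(size=(100, 2))  # 100 samples, 2 parents
effect = mech.predict(parents)
print(effect.shape)  # (100,)
```

Building the weights lazily lets one object accept any number of parent columns without declaring it up front; whether the library defers construction in exactly this way is an assumption of the sketch.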