Skip to content

Lightweight framework to make neural networks

Notifications You must be signed in to change notification settings

xEcho1337/brain4j

Repository files navigation

Brain4J

Brain4J

Brain4J is a powerful, lightweight, and easy-to-use Machine Learning library written in Java, designed for speed and simplicity.


Getting Started

When building a neural network, you have many options. In this example, we will create a neural network to simulate an XOR gate.

Defining the Model

To represent an XOR gate, we can use a simple neural network with four layers:

Model network = new Model(
        new DenseLayer(2, Activations.LINEAR),
        new DenseLayer(16, Activations.RELU),
        new DenseLayer(16, Activations.RELU),
        new DenseLayer(1, Activations.SIGMOID)
);

Compiling the Model

Next, define the weight initialization method and the loss function for training. Use the compile method as follows:

network.compile(
        WeightInitialization.HE,
        LossFunctions.BINARY_CROSS_ENTROPY,
        new Adam(0.1),
        new StochasticUpdater()
);

For models with a single output neuron (producing values between 0 and 1), Binary Cross Entropy is the recommended loss function, paired with the Adam optimizer.

Also, when using the ReLU activation function it's suggested to use the He weight initialization for better results.

Preparing Training Data

Create your training dataset using DataSet and DataRow:

DataRow first = new DataRow(Vector.of(0, 0), Vector.of(0));
DataRow second = new DataRow(Vector.of(0, 1), Vector.of(1));
DataRow third = new DataRow(Vector.of(1, 0), Vector.of(1));
DataRow fourth = new DataRow(Vector.of(1, 1), Vector.of(0));

DataSet training = new DataSet(first, second, third, fourth);

Training the Model

Once the setup is complete, use the fit method inside a loop to train the network. Training stops when the error is below a certain threshold.

Tip: Always split your dataset into training and testing sets to evaluate the model’s performance.
double error;

do {
    network.fit(training, 1);
    error = network.evaluate(training);
} while (error > 0.01);

The above code trains the neural network with a learning rate of 0.001, stopping when the error falls below 1%.

Contributing & Contact

Contributions are always welcome via pull requests or issue reports.

  • Telegram: @nettyfan
  • Discord: @xecho1337

About

Lightweight framework to make neural networks

Resources

Stars

Watchers

Forks

Packages

No packages published

Languages