{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we will load a dataset from `scikit-learn` and use it to create a custom `Dataset` object in _Olympus_." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from olympus import Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# load the boston dataset from sklearn\n", "from sklearn.datasets import load_boston\n", "boston = load_boston()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>CRIM</th><th>ZN</th><th>INDUS</th><th>CHAS</th><th>NOX</th><th>RM</th><th>AGE</th><th>DIS</th><th>RAD</th><th>TAX</th><th>PTRATIO</th><th>B</th><th>LSTAT</th><th>target</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>0.00632</td><td>18.0</td><td>2.31</td><td>0.0</td><td>0.538</td><td>6.575</td><td>65.2</td><td>4.0900</td><td>1.0</td><td>296.0</td><td>15.3</td><td>396.90</td><td>4.98</td><td>24.0</td></tr>\n",
"    <tr><th>1</th><td>0.02731</td><td>0.0</td><td>7.07</td><td>0.0</td><td>0.469</td><td>6.421</td><td>78.9</td><td>4.9671</td><td>2.0</td><td>242.0</td><td>17.8</td><td>396.90</td><td>9.14</td><td>21.6</td></tr>\n",
"    <tr><th>2</th><td>0.02729</td><td>0.0</td><td>7.07</td><td>0.0</td><td>0.469</td><td>7.185</td><td>61.1</td><td>4.9671</td><td>2.0</td><td>242.0</td><td>17.8</td><td>392.83</td><td>4.03</td><td>34.7</td></tr>\n",
"    <tr><th>3</th><td>0.03237</td><td>0.0</td><td>2.18</td><td>0.0</td><td>0.458</td><td>6.998</td><td>45.8</td><td>6.0622</td><td>3.0</td><td>222.0</td><td>18.7</td><td>394.63</td><td>2.94</td><td>33.4</td></tr>\n",
"    <tr><th>4</th><td>0.06905</td><td>0.0</td><td>2.18</td><td>0.0</td><td>0.458</td><td>7.147</td><td>54.2</td><td>6.0622</td><td>3.0</td><td>222.0</td><td>18.7</td><td>396.90</td><td>5.33</td><td>36.2</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>501</th><td>0.06263</td><td>0.0</td><td>11.93</td><td>0.0</td><td>0.573</td><td>6.593</td><td>69.1</td><td>2.4786</td><td>1.0</td><td>273.0</td><td>21.0</td><td>391.99</td><td>9.67</td><td>22.4</td></tr>\n",
"    <tr><th>502</th><td>0.04527</td><td>0.0</td><td>11.93</td><td>0.0</td><td>0.573</td><td>6.120</td><td>76.7</td><td>2.2875</td><td>1.0</td><td>273.0</td><td>21.0</td><td>396.90</td><td>9.08</td><td>20.6</td></tr>\n",
"    <tr><th>503</th><td>0.06076</td><td>0.0</td><td>11.93</td><td>0.0</td><td>0.573</td><td>6.976</td><td>91.0</td><td>2.1675</td><td>1.0</td><td>273.0</td><td>21.0</td><td>396.90</td><td>5.64</td><td>23.9</td></tr>\n",
"    <tr><th>504</th><td>0.10959</td><td>0.0</td><td>11.93</td><td>0.0</td><td>0.573</td><td>6.794</td><td>89.3</td><td>2.3889</td><td>1.0</td><td>273.0</td><td>21.0</td><td>393.45</td><td>6.48</td><td>22.0</td></tr>\n",
"    <tr><th>505</th><td>0.04741</td><td>0.0</td><td>11.93</td><td>0.0</td><td>0.573</td><td>6.030</td><td>80.8</td><td>2.5050</td><td>1.0</td><td>273.0</td><td>21.0</td><td>396.90</td><td>7.88</td><td>11.9</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>506 rows × 14 columns</p>\n",
"</div>
" ], "text/plain": [ " CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX \\\n", "0 0.00632 18.0 2.31 0.0 0.538 6.575 65.2 4.0900 1.0 296.0 \n", "1 0.02731 0.0 7.07 0.0 0.469 6.421 78.9 4.9671 2.0 242.0 \n", "2 0.02729 0.0 7.07 0.0 0.469 7.185 61.1 4.9671 2.0 242.0 \n", "3 0.03237 0.0 2.18 0.0 0.458 6.998 45.8 6.0622 3.0 222.0 \n", "4 0.06905 0.0 2.18 0.0 0.458 7.147 54.2 6.0622 3.0 222.0 \n", ".. ... ... ... ... ... ... ... ... ... ... \n", "501 0.06263 0.0 11.93 0.0 0.573 6.593 69.1 2.4786 1.0 273.0 \n", "502 0.04527 0.0 11.93 0.0 0.573 6.120 76.7 2.2875 1.0 273.0 \n", "503 0.06076 0.0 11.93 0.0 0.573 6.976 91.0 2.1675 1.0 273.0 \n", "504 0.10959 0.0 11.93 0.0 0.573 6.794 89.3 2.3889 1.0 273.0 \n", "505 0.04741 0.0 11.93 0.0 0.573 6.030 80.8 2.5050 1.0 273.0 \n", "\n", " PTRATIO B LSTAT target \n", "0 15.3 396.90 4.98 24.0 \n", "1 17.8 396.90 9.14 21.6 \n", "2 17.8 392.83 4.03 34.7 \n", "3 18.7 394.63 2.94 33.4 \n", "4 18.7 396.90 5.33 36.2 \n", ".. ... ... ... ... \n", "501 21.0 391.99 9.67 22.4 \n", "502 21.0 396.90 9.08 20.6 \n", "503 21.0 396.90 5.64 23.9 \n", "504 21.0 393.45 6.48 22.0 \n", "505 21.0 396.90 7.88 11.9 \n", "\n", "[506 rows x 14 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# concatenate the features and targets into single lists/arrays and use the to create a pandas dataframe\n", "data = np.c_[boston['data'], boston['target']]\n", "columns = list(boston['feature_names'])\n", "columns.append('target')\n", "\n", "df = pd.DataFrame(data=data, columns=columns)\n", "df" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# pass the Dataframe as the data argument for Dataset and specify which one is the target variable\n", "dataset = Dataset(data=df, target_ids=['target'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now `dataset` is an instance of the _Olympus_ class `Dataset`. 
However, before we can use it to train a custom `Emulator`, we need to specify the parameter space for this dataset/problem." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from olympus import ParameterSpace, Parameter\n", "\n", "# initialise a parameter space object\n", "param_space = ParameterSpace()\n", "\n", "# add each feature in the dataset as a variable in the parameter space\n", "for feature in dataset.features:\n", "    low = np.min(dataset.data[feature]) # take the min in the data\n", "    high = np.max(dataset.data[feature]) # take the max in the data\n", "    param = Parameter(kind='continuous', name=feature, low=low, high=high)\n", "    param_space.add(param)\n", "\n", "dataset.set_param_space(param_space)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that in the above code we set the bounds of each parameter to the minimum and maximum values observed in the dataset. The same result can be achieved with the `infer_param_space` method of `Dataset`, as follows:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "dataset.infer_param_space()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, most often you will want these bounds to depend on the details of your problem, in which case you can explicitly specify the bounds for all parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we define a small Bayesian neural network and test its performance in emulating this dataset. Note that, by default, `Dataset` creates 5 random folds for cross-validation and reserves 20% of the data for testing." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from olympus import Emulator\n", "from olympus.models import BayesNeuralNet\n", "\n", "mymodel = BayesNeuralNet(hidden_depth=2, hidden_nodes=12, hidden_act='leaky_relu', out_act=\"relu\", \n", " batch_size=50, reg=0.005, max_epochs=10000)\n", "emulator = Emulator(dataset=dataset, model=mymodel, feature_transform='normalize', target_transform='normalize')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0;37m[INFO] >>> Training model on 80% of the dataset, testing on 20%...\n", "\u001b[0mWARNING:tensorflow:From /Users/Matteo/anaconda2/envs/olympus/lib/python3.7/site-packages/tensorflow_probability/python/layers/util.py:104: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use `layer.add_weight` method instead.\n", "WARNING:tensorflow:From /Users/Matteo/anaconda2/envs/olympus/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "If using Keras pass *_constraint arguments to layers.\n", "\u001b[0;37m[INFO] =======================================================================\n", "\u001b[0m\u001b[0;37m[INFO] Epoch Train R2 Train RMSD Test R2 Test RMSD\n", "\u001b[0m\u001b[0;37m[INFO] =======================================================================\n", "\u001b[0m\u001b[0;37m[INFO] 0 -4.011 0.483 -5.083 0.404 *\n", "\u001b[0m\u001b[0;37m[INFO] 100 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 200 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 300 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 400 
-4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 500 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 600 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 700 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 800 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 900 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 1000 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 1100 -4.011 0.483 -5.083 0.404\n", "\u001b[0m\u001b[0;37m[INFO] 1200 -3.730 0.470 -4.805 0.394 *\n", "\u001b[0m\u001b[0;37m[INFO] 1300 0.392 0.168 0.536 0.111 *\n", "\u001b[0m\u001b[0;37m[INFO] 1400 0.579 0.140 0.655 0.096 *\n", "\u001b[0m\u001b[0;37m[INFO] 1500 0.617 0.134 0.721 0.086 *\n", "\u001b[0m\u001b[0;37m[INFO] 1600 0.696 0.119 0.753 0.081 *\n", "\u001b[0m\u001b[0;37m[INFO] 1700 0.665 0.125 0.788 0.075 *\n", "\u001b[0m\u001b[0;37m[INFO] 1800 0.681 0.122 0.796 0.074 *\n", "\u001b[0m\u001b[0;37m[INFO] 1900 0.704 0.118 0.795 0.074\n", "\u001b[0m\u001b[0;37m[INFO] 2000 0.769 0.104 0.787 0.076\n", "\u001b[0m\u001b[0;37m[INFO] 2100 0.792 0.099 0.796 0.074\n", "\u001b[0m\u001b[0;37m[INFO] 2200 0.795 0.098 0.800 0.073 *\n", "\u001b[0m\u001b[0;37m[INFO] 2300 0.807 0.095 0.782 0.076\n", "\u001b[0m\u001b[0;37m[INFO] 2400 0.775 0.102 0.829 0.068 *\n", "\u001b[0m\u001b[0;37m[INFO] 2500 0.801 0.096 0.805 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 2600 0.836 0.088 0.808 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 2700 0.834 0.088 0.807 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 2800 0.825 0.090 0.812 0.071\n", "\u001b[0m\u001b[0;37m[INFO] 2900 0.841 0.086 0.796 0.074\n", "\u001b[0m\u001b[0;37m[INFO] 3000 0.830 0.089 0.805 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 3100 0.842 0.086 0.821 0.069\n", "\u001b[0m\u001b[0;37m[INFO] 3200 0.848 0.084 0.824 0.069\n", "\u001b[0m\u001b[0;37m[INFO] 3300 0.843 0.086 0.805 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 3400 0.862 0.080 0.811 0.071\n", "\u001b[0m\u001b[0;37m[INFO] 3500 0.861 0.080 0.794 0.074\n", 
"\u001b[0m\u001b[0;37m[INFO] 3600 0.880 0.075 0.796 0.074\n", "\u001b[0m\u001b[0;37m[INFO] 3700 0.872 0.077 0.805 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 3800 0.883 0.074 0.806 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 3900 0.881 0.074 0.810 0.071\n", "\u001b[0m\u001b[0;37m[INFO] 4000 0.877 0.076 0.816 0.070\n", "\u001b[0m\u001b[0;37m[INFO] 4100 0.882 0.074 0.824 0.069\n", "\u001b[0m\u001b[0;37m[INFO] 4200 0.887 0.073 0.815 0.070\n", "\u001b[0m\u001b[0;37m[INFO] 4300 0.882 0.074 0.815 0.070\n", "\u001b[0m\u001b[0;37m[INFO] 4400 0.861 0.080 0.806 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 4500 0.888 0.072 0.813 0.071\n", "\u001b[0m\u001b[0;37m[INFO] 4600 0.889 0.072 0.809 0.071\n", "\u001b[0m\u001b[0;37m[INFO] 4700 0.901 0.068 0.820 0.069\n", "\u001b[0m\u001b[0;37m[INFO] 4800 0.887 0.073 0.825 0.068\n", "\u001b[0m\u001b[0;37m[INFO] 4900 0.899 0.069 0.805 0.072\n", "\u001b[0m\u001b[0;37m[INFO] 5000 0.909 0.065 0.819 0.070\n", "\u001b[0m\u001b[0;37m[INFO] 5100 0.906 0.066 0.822 0.069\n", "\u001b[0m\u001b[0;37m[INFO] 5200 0.919 0.061 0.827 0.068\n", "\u001b[0m\u001b[0;37m[INFO] 5300 0.914 0.063 0.824 0.069\n", "\u001b[0m\u001b[0;37m[INFO] 5400 0.917 0.062 0.833 0.067 *\n", "\u001b[0m\u001b[0;37m[INFO] 5500 0.915 0.063 0.835 0.066 *\n", "\u001b[0m\u001b[0;37m[INFO] 5600 0.920 0.061 0.846 0.064 *\n", "\u001b[0m\u001b[0;37m[INFO] 5700 0.926 0.059 0.851 0.063 *\n", "\u001b[0m\u001b[0;37m[INFO] 5800 0.927 0.058 0.844 0.065\n", "\u001b[0m\u001b[0;37m[INFO] 5900 0.929 0.058 0.851 0.063\n", "\u001b[0m\u001b[0;37m[INFO] 6000 0.922 0.060 0.846 0.064\n", "\u001b[0m\u001b[0;37m[INFO] 6100 0.932 0.057 0.848 0.064\n", "\u001b[0m\u001b[0;37m[INFO] 6200 0.931 0.057 0.861 0.061 *\n", "\u001b[0m\u001b[0;37m[INFO] 6300 0.929 0.057 0.863 0.061 *\n", "\u001b[0m\u001b[0;37m[INFO] 6400 0.930 0.057 0.866 0.060 *\n", "\u001b[0m\u001b[0;37m[INFO] 6500 0.934 0.056 0.865 0.060\n", "\u001b[0m\u001b[0;37m[INFO] 6600 0.930 0.057 0.857 0.062\n", "\u001b[0m\u001b[0;37m[INFO] 6700 0.930 0.057 0.878 0.057 *\n", 
"\u001b[0m\u001b[0;37m[INFO] 6800 0.935 0.055 0.859 0.061\n", "\u001b[0m\u001b[0;37m[INFO] 6900 0.930 0.057 0.876 0.058\n", "\u001b[0m\u001b[0;37m[INFO] 7000 0.930 0.057 0.890 0.054 *\n", "\u001b[0m\u001b[0;37m[INFO] 7100 0.927 0.058 0.882 0.056\n", "\u001b[0m\u001b[0;37m[INFO] 7200 0.938 0.054 0.872 0.059\n", "\u001b[0m\u001b[0;37m[INFO] 7300 0.935 0.055 0.870 0.059\n", "\u001b[0m\u001b[0;37m[INFO] 7400 0.932 0.056 0.887 0.055\n", "\u001b[0m\u001b[0;37m[INFO] 7500 0.936 0.055 0.872 0.059\n", "\u001b[0m\u001b[0;37m[INFO] 7600 0.928 0.058 0.886 0.055\n", "\u001b[0m\u001b[0;37m[INFO] 7700 0.934 0.055 0.876 0.058\n", "\u001b[0m\u001b[0;37m[INFO] 7800 0.946 0.050 0.866 0.060\n", "\u001b[0m\u001b[0;37m[INFO] 7900 0.944 0.051 0.880 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 8000 0.941 0.052 0.878 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 8100 0.942 0.052 0.884 0.056\n", "\u001b[0m\u001b[0;37m[INFO] 8200 0.934 0.056 0.873 0.058\n", "\u001b[0m\u001b[0;37m[INFO] 8300 0.939 0.053 0.879 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 8400 0.935 0.055 0.872 0.059\n", "\u001b[0m\u001b[0;37m[INFO] 8500 0.941 0.053 0.879 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 8600 0.943 0.051 0.890 0.054 *\n", "\u001b[0m\u001b[0;37m[INFO] 8700 0.931 0.057 0.896 0.053 *\n", "\u001b[0m\u001b[0;37m[INFO] 8800 0.936 0.055 0.889 0.055\n", "\u001b[0m\u001b[0;37m[INFO] 8900 0.934 0.056 0.881 0.056\n", "\u001b[0m\u001b[0;37m[INFO] 9000 0.942 0.052 0.880 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 9100 0.940 0.053 0.879 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 9200 0.938 0.054 0.877 0.057\n", "\u001b[0m\u001b[0;37m[INFO] 9300 0.944 0.051 0.883 0.056\n", "\u001b[0m\u001b[0;37m[INFO] 9400 0.937 0.054 0.882 0.056\n", "\u001b[0m\u001b[0;37m[INFO] 9500 0.932 0.056 0.904 0.051 *\n", "\u001b[0m\u001b[0;37m[INFO] 9600 0.942 0.052 0.889 0.055\n", "\u001b[0m\u001b[0;37m[INFO] 9700 0.946 0.050 0.885 0.055\n", "\u001b[0m\u001b[0;37m[INFO] 9800 0.944 0.051 0.890 0.054\n", "\u001b[0m\u001b[0;37m[INFO] 9900 0.939 0.053 0.883 0.056\n", 
"\u001b[0m\u001b[0;37m[INFO] Training completed in 10.18 seconds.\n", "\u001b[0m\u001b[0;37m[INFO] ===========================================================================\n", "\n", "\u001b[0m\u001b[0;37m[INFO] Train R2 Score: 0.9322\n", "\u001b[0m\u001b[0;37m[INFO] Test R2 Score: 0.9044\n", "\u001b[0m\u001b[0;37m[INFO] Train RMSD Score: 0.0562\n", "\u001b[0m\u001b[0;37m[INFO] Test RMSD Score: 0.0506\n", "\u001b[0m" ] }, { "data": { "text/plain": [ "{'train_r2': 0.932159417784379,\n", " 'test_r2': 0.9044223215989219,\n", " 'train_rmsd': 0.05624565811984018,\n", " 'test_rmsd': 0.05059158705037693}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emulator.train()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's now say you would like to share this dataset with the community by uploading it to the _Olympus Datasets_. You can do this with the `upload` command line tool in _Olympus_ as described in the documentation. However, you first need to prepare the dataset in the expected format. One way to easily do this is to use the `to_disk` method available to `Dataset` objects." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# save dataset to disk\n", "dataset.to_disk('custom_dataset')" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "config.json data.csv description.txt\n" ] } ], "source": [ "!ls custom_dataset/" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "olympus", "language": "python", "name": "olympus" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }