{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Select Parameters with Cross Validation\n",
"\n",
"This notebook demonstrates how to select parameters of best fit for a sparse tensor decomposition model using a cross-validation strategy. \n",
"\n",
"Fitting a sparse tensor decomposition model requires tuning two key parameters: the number of components (rank) of the model, and the sparsity coefficient (lambda) applied to each mode of the model. These parameters are rarely known a priori, and instead must be inferred. We've developed a cross-validation protocol for selecting parameters of best fit. This protocol relies on measurement replicates to generate three replicate data tensors of equivalent shapes. These replicate data tensors can be used to calculate cross-validated sum of squared error (SSE) and factor match score (FMS) metrics that can be used to identify the parameters of best fit.\n",
"\n",
"We will first simulate data replicates by generating three identical simulated data tensors -- `rep_a`, `rep_b`, and `rep_c` -- with independent noise added to each tensor to simulate variation between replicates. Next we will identify the best fit number of components using cross validated SSE. Locking in this rank, we'll fine tune the best fit sparsity coefficient using cross validated FMS."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy\n",
"import seaborn as sns\n",
"import tensorly as tl\n",
"import tlviz\n",
"from barnacle import (\n",
" SparseCP, \n",
" visualize_3d_tensor, \n",
" simulated_sparse_tensor, \n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Replicate A\n"
]
},
{
"data": {
"text/html": [
" \n",
" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"hovertemplate": "axis0=%{x}
axis1=%{y}
axis2=%{z}
abs_exp=%{marker.size}
abundance=%{marker.color}

| | Rank | Fitting Replicate | Comparison Replicate | Comparison | SSE |
|---|---|---|---|---|---|
| 0 | 1 | A | A | fitting | 0.657666 |
| 1 | 1 | A | B | cross-validation | 0.707090 |
| 2 | 1 | A | C | cross-validation | 0.714235 |
| 3 | 1 | B | A | cross-validation | 0.674137 |
| 4 | 1 | B | B | fitting | 0.691348 |
| ... | ... | ... | ... | ... | ... |
| 67 | 8 | B | B | fitting | 0.352665 |
| 68 | 8 | B | C | cross-validation | 0.648890 |
| 69 | 8 | C | A | cross-validation | 0.645520 |
| 70 | 8 | C | B | cross-validation | 0.646472 |
| 71 | 8 | C | C | fitting | 0.341968 |

72 rows × 5 columns
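
The table above scores each rank's model against its own fitting replicate and against the two held-out replicates. Below is a minimal sketch of that bookkeeping. It assumes SSE is normalized by the squared norm of the comparison tensor; the notebook's exact normalization is not shown here. To stay self-contained, the sketch swaps `SparseCP` for a plainly different technique, a rank-r truncated SVD of the mode-0 unfolding. The helpers `relative_sse` and `low_rank_fit` are hypothetical and not part of barnacle.

```python
import numpy as np


def relative_sse(reference, estimate):
    """Squared error of `estimate` against `reference`, divided by ||reference||^2 (assumed normalization)."""
    return float(np.sum((reference - estimate) ** 2) / np.sum(reference**2))


def low_rank_fit(tensor, rank):
    """Hypothetical stand-in for a SparseCP fit: rank-r truncated SVD of the mode-0 unfolding."""
    unfolded = tensor.reshape(tensor.shape[0], -1)
    u, s, vt = np.linalg.svd(unfolded, full_matrices=False)
    approx = (u[:, :rank] * s[:rank]) @ vt[:rank]
    return approx.reshape(tensor.shape)


rng = np.random.default_rng(0)
shape, true_rank = (20, 15, 10), 3

# One shared low-rank signal; each replicate gets its own independent noise.
factors = [rng.random((n, true_rank)) for n in shape]
signal = np.einsum("ir,jr,kr->ijk", *factors)
replicates = {name: signal + 0.2 * rng.standard_normal(shape) for name in "ABC"}

# Same layout as the table: (rank, fitting rep, comparison rep, comparison, SSE).
results = []
for rank in range(1, 9):
    for fit_name, fit_data in replicates.items():
        model = low_rank_fit(fit_data, rank)
        for cmp_name, cmp_data in replicates.items():
            kind = "fitting" if cmp_name == fit_name else "cross-validation"
            results.append((rank, fit_name, cmp_name, kind, relative_sse(cmp_data, model)))

# Average the fitting and cross-validation errors at each rank.
for rank in range(1, 9):
    fit = np.mean([r[4] for r in results if r[0] == rank and r[3] == "fitting"])
    cv = np.mean([r[4] for r in results if r[0] == rank and r[3] == "cross-validation"])
    print(f"rank {rank}: fitting {fit:.3f}, cross-validated {cv:.3f}")
```

Each truncated SVD is the best rank-r approximation of its own unfolding, so fitting SSE can only fall as rank grows. Cross-validated SSE typically stops improving once extra components start modeling noise that belongs only to the fitting replicate. The rank-8 rows shown in the table have the same kind of gap: fitting SSE is about 0.34 to 0.35, while cross-validated SSE is about 0.65.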

| | Sparsity Coefficient | Fitting Replicate | Comparison Replicate | SSE | FMS |
|---|---|---|---|---|---|
| 0 | 0.05 | A | B | 0.563678 | 0.621984 |
| 1 | 0.05 | A | C | 0.575991 | 0.587359 |
| 2 | 0.05 | B | A | 0.573267 | NaN |
| 3 | 0.05 | B | C | 0.561255 | 0.665616 |
| 4 | 0.05 | C | A | 0.549828 | NaN |
| 5 | 0.05 | C | B | 0.543003 | NaN |
| 6 | 0.10 | A | B | 0.569468 | 0.621996 |
| 7 | 0.10 | A | C | 0.563350 | 0.559078 |
| 8 | 0.10 | B | A | 0.553518 | NaN |
| 9 | 0.10 | B | C | 0.568633 | 0.580420 |
| 10 | 0.10 | C | A | 0.546041 | NaN |
| 11 | 0.10 | C | B | 0.538591 | NaN |
| 12 | 0.20 | A | B | 0.565581 | 0.601660 |
| 13 | 0.20 | A | C | 0.558382 | 0.559783 |
| 14 | 0.20 | B | A | 0.542834 | NaN |
| 15 | 0.20 | B | C | 0.533984 | 0.773011 |
| 16 | 0.20 | C | A | 0.544073 | NaN |
| 17 | 0.20 | C | B | 0.534930 | NaN |
| 18 | 0.30 | A | B | 0.564586 | 0.603628 |
| 19 | 0.30 | A | C | 0.558994 | 0.558960 |
| 20 | 0.30 | B | A | 0.540832 | NaN |
| 21 | 0.30 | B | C | 0.532331 | 0.783579 |
| 22 | 0.30 | C | A | 0.544599 | NaN |
| 23 | 0.30 | C | B | 0.534873 | NaN |
| 24 | 0.40 | A | B | 0.567422 | 0.601039 |
| 25 | 0.40 | A | C | 0.562392 | 0.555470 |
| 26 | 0.40 | B | A | 0.541534 | NaN |
| 27 | 0.40 | B | C | 0.533658 | 0.789679 |
| 28 | 0.40 | C | A | 0.548292 | NaN |
| 29 | 0.40 | C | B | 0.537406 | NaN |
| 30 | 0.50 | A | B | 0.571722 | 0.590569 |
| 31 | 0.50 | A | C | 0.567495 | 0.550812 |
| 32 | 0.50 | B | A | 0.549326 | NaN |
| 33 | 0.50 | B | C | 0.565389 | 0.619288 |
| 34 | 0.50 | C | A | 0.554570 | NaN |
| 35 | 0.50 | C | B | 0.542781 | NaN |
| 36 | 0.60 | A | B | 0.577466 | 0.577193 |
| 37 | 0.60 | A | C | 0.574194 | 0.542921 |
| 38 | 0.60 | B | A | 0.555087 | NaN |
| 39 | 0.60 | B | C | 0.570769 | 0.620740 |
| 40 | 0.60 | C | A | 0.561618 | NaN |
| 41 | 0.60 | C | B | 0.549689 | NaN |
| 42 | 0.80 | A | B | 0.594029 | 0.437480 |
| 43 | 0.80 | A | C | 0.592781 | 0.523104 |
| 44 | 0.80 | B | A | 0.606608 | NaN |
| 45 | 0.80 | B | C | 0.617170 | 0.465549 |
| 46 | 0.80 | C | A | 0.577298 | NaN |
| 47 | 0.80 | C | B | 0.566667 | NaN |
| 48 | 1.00 | A | B | 0.635231 | 0.313387 |
| 49 | 1.00 | A | C | 0.636494 | 0.521444 |
| 50 | 1.00 | B | A | 0.627666 | NaN |
| 51 | 1.00 | B | C | 0.639601 | 0.454012 |
| 52 | 1.00 | C | A | 0.597917 | NaN |
| 53 | 1.00 | C | B | 0.590343 | NaN |
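
One simple way to read the sparsity sweep is to average the cross-validated FMS over replicate pairs at each coefficient. The sketch below does this with pandas, using the FMS values copied from the table above. Averaging is an illustrative choice and not necessarily the selection rule this notebook applies. The variable names are hypothetical.

```python
import pandas as pd

# Cross-validated FMS values transcribed from the table above. Only the
# A-B, A-C and B-C rows carry an FMS value; the reversed pairs are NaN there.
fms_by_coefficient = {
    0.05: [0.621984, 0.587359, 0.665616],
    0.10: [0.621996, 0.559078, 0.580420],
    0.20: [0.601660, 0.559783, 0.773011],
    0.30: [0.603628, 0.558960, 0.783579],
    0.40: [0.601039, 0.555470, 0.789679],
    0.50: [0.590569, 0.550812, 0.619288],
    0.60: [0.577193, 0.542921, 0.620740],
    0.80: [0.437480, 0.523104, 0.465549],
    1.00: [0.313387, 0.521444, 0.454012],
}

# Long-format DataFrame: one row per (coefficient, replicate pair).
fms = pd.DataFrame(
    [
        {"Sparsity Coefficient": coef, "Pair": pair, "FMS": value}
        for coef, values in fms_by_coefficient.items()
        for pair, value in zip(["A-B", "A-C", "B-C"], values)
    ]
)

# Mean and spread of FMS across the three replicate pairs.
summary = fms.groupby("Sparsity Coefficient")["FMS"].agg(["mean", "std"])
print(summary.round(3))

best = summary["mean"].idxmax()
print(f"Highest mean FMS: {summary.loc[best, 'mean']:.4f} at coefficient {best}")
```

With these numbers, mean FMS is highest at 0.3 and 0.4, which lie within about 1e-5 of each other, and it drops sharply at 0.8 and 1.0. If the results are already in a DataFrame laid out like the table, `groupby("Sparsity Coefficient")["FMS"].mean()` gives the same means, because pandas skips NaN by default.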