{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate Model Performance\n",
"\n",
"This notebook demonstrates how to evaluate the performance of a sparse tensor decomposition model fit to simulated data, comparing the model against both the data and the ground-truth components used to generate the simulation. Model performance is evaluated with five metrics:\n",
"\n",
"1. Relative sum of squared errors (SSE)\n",
"   - Measures how closely the model reconstruction matches the data\n",
"1. Factor match score (FMS)\n",
"   - Measures how closely the components of two models (here, the fitted model and the ground truth) match one another\n",
"1. Precision\n",
"   - Proportion of test-cluster membership that agrees with the ground truth\n",
"1. Recall\n",
"   - Proportion of ground-truth membership recovered by the test clusters\n",
"1. F1 score\n",
"   - Harmonic mean of precision and recall\n",
" \n",
"SSE and FMS evaluate the model as a whole, whereas precision, recall, and F1 score evaluate clusters derived from the factor matrix of a single mode. Here we derive clusters from mode 0: each component defines a cluster whose members are the indices with non-zero weights in that component. Note that Barnacle calculates precision and recall according to the formula proposed by [Saelens et al. (2018)](https://www.nature.com/articles/s41467-018-03424-4) so that overlapping clusters are handled correctly. For more on these metrics and their usage, see the article published alongside the Barnacle library.\n"
]
},
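{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal, hypothetical sketch of the cluster-based metrics described above. It derives mode-0 clusters from the non-zero weights of each component and scores them with a generic pair-counting precision and recall, in which two indices form a pair whenever they share at least one cluster, so overlapping clusters are allowed. This is an illustrative assumption, not Barnacle's `pairs_precision_recall` implementation or the exact Saelens et al. (2018) formula; the helper names `clusters_from_factor`, `co_membership_pairs`, and `pair_precision_recall` and the toy factor matrices are made up for this example.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch, not Barnacle's implementation: derive mode-0 clusters\n",
"# from non-zero factor weights and score them with a generic pair-counting\n",
"# precision/recall that tolerates overlapping clusters.\n",
"from itertools import combinations\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def clusters_from_factor(factor, tol=0.0):\n",
"    '''Return one set of row indices per component whose |weight| > tol.'''\n",
"    return [set(np.flatnonzero(np.abs(factor[:, r]) > tol).tolist())\n",
"            for r in range(factor.shape[1])]\n",
"\n",
"\n",
"def co_membership_pairs(clusters):\n",
"    '''Collect every unordered index pair that shares at least one cluster.'''\n",
"    pairs = set()\n",
"    for cluster in clusters:\n",
"        pairs.update(combinations(sorted(cluster), 2))\n",
"    return pairs\n",
"\n",
"\n",
"def pair_precision_recall(test_clusters, true_clusters):\n",
"    '''Pair-counting precision and recall of test clusters vs. ground truth.'''\n",
"    test_pairs = co_membership_pairs(test_clusters)\n",
"    true_pairs = co_membership_pairs(true_clusters)\n",
"    shared = len(test_pairs & true_pairs)\n",
"    precision = shared / len(test_pairs) if test_pairs else 0.0\n",
"    recall = shared / len(true_pairs) if true_pairs else 0.0\n",
"    return precision, recall\n",
"\n",
"\n",
"# toy mode-0 factor matrices (6 rows x 2 components), for illustration only\n",
"true_factor = np.array([[1, 0], [1, 0], [1, 1], [0, 1], [0, 1], [0, 0]], dtype=float)\n",
"test_factor = np.array([[0.9, 0], [0.8, 0], [0, 0.7], [0, 0.5], [0, 0], [0, 0]])\n",
"precision, recall = pair_precision_recall(\n",
"    clusters_from_factor(test_factor), clusters_from_factor(true_factor)\n",
")\n",
"print(precision, recall)"
]
},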
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy\n",
"import seaborn as sns\n",
"import tensorly as tl\n",
"import tlviz\n",
"from barnacle import (\n",
"    SparseCP,\n",
"    visualize_3d_tensor,\n",
"    simulated_sparse_tensor,\n",
"    pairs_precision_recall,\n",
")\n",
"\n",
"# helper function to calculate the F1 score from composite precision and recall scores\n",
"def composite_f1(precision, recall):\n",
"    '''Calculate the F1 score (harmonic mean) from precision and recall.'''\n",
"    denominator = precision + recall\n",
"    if denominator == 0:\n",
"        return 0\n",
"    else:\n",
"        return (2 * precision * recall) / denominator\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" \n",
" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"hovertemplate": "axis0=%{x}<br>axis1=%{y}<br>axis2=%{z}<br>abs_exp=%{marker.size}<br>abundance=%{marker.color}
|  | Rank | Sparsity Coefficient | Replicate | SSE | FMS | Precision | Recall | F1 |
|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.2 | 0 | 0.384472 | 0.870112 | 0.439394 | 0.852941 | 0.580000 |
| 1 | 1 | 0.2 | 1 | 0.397529 | 0.910833 | 0.439394 | 0.852941 | 0.580000 |
| 2 | 1 | 0.2 | 2 | 0.384020 | 0.848842 | 0.439394 | 0.852941 | 0.580000 |
| 3 | 1 | 0.2 | 3 | 0.390881 | 0.868323 | 0.436364 | 0.705882 | 0.539326 |
| 4 | 1 | 0.2 | 4 | 0.400358 | 0.869343 | 0.436364 | 0.705882 | 0.539326 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 65 | 5 | 2.0 | 0 | 0.506489 | 0.135875 | 1.000000 | 0.029412 | 0.057143 |
| 66 | 5 | 2.0 | 1 | 0.509646 | 0.147198 | 1.000000 | 0.029412 | 0.057143 |
| 67 | 5 | 2.0 | 2 | 0.485476 | 0.160329 | 1.000000 | 0.029412 | 0.057143 |
| 68 | 5 | 2.0 | 3 | 0.512247 | 0.136309 | 1.000000 | 0.029412 | 0.057143 |
| 69 | 5 | 2.0 | 4 | 0.535860 | 0.139017 | 1.000000 | 0.029412 | 0.057143 |
70 rows × 8 columns
\n", "