{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate a Simulated Data Tensor\n",
"\n",
"This notebook demonstrates how to generate a simulated data tensor using the `simulated_sparse_tensor()` function of the `Barnacle` library. In this example we generate a third-order (three-way) data tensor from 5 sparse simulated components. We use the `barnacle.visualize_3d_tensor()` function to visualize the simulated data tensor as well as the components used to generate it, and the `barnacle.plot_factors_heatmap()` function to visualize the factor matrices from which those simulated components are built."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# imports\n",
"\n",
"import numpy as np\n",
"import scipy\n",
"import tensorly as tl\n",
"from barnacle import (\n",
" visualize_3d_tensor, \n",
" simulated_sparse_tensor, \n",
" plot_factors_heatmap\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5\n",
"5\n",
"5\n"
]
}
],
"source": [
"# generate simulated data tensor\n",
"\n",
"true_rank = 5\n",
"true_shape = [15, 20, 10]\n",
"true_densities = [.4, .2, .6]\n",
"\n",
"# Re-seed the simulation until every factor matrix has rank equal to true_rank.\n",
"# The seed must change between attempts: with a fixed random_state every draw is\n",
"# identical, so a rank-deficient first draw would otherwise loop forever.\n",
"initial_seed = 9481\n",
"max_attempts = 100\n",
"for attempt in range(max_attempts):\n",
"    sim_tensor = simulated_sparse_tensor(\n",
"        shape=true_shape,\n",
"        rank=true_rank,\n",
"        densities=true_densities,\n",
"        factor_dist_list=[\n",
"            scipy.stats.uniform(),\n",
"            scipy.stats.uniform(loc=-1, scale=2),\n",
"            scipy.stats.uniform()\n",
"        ],\n",
"        random_state=initial_seed + attempt\n",
"    )\n",
"    # stop at the first draw whose factor matrices are all full rank\n",
"    if all(np.linalg.matrix_rank(factor) == true_rank for factor in sim_tensor.factors):\n",
"        break\n",
"else:\n",
"    # loop exhausted without a break: fail loudly instead of hanging\n",
"    raise RuntimeError(\n",
"        f\"no full-rank simulation found in {max_attempts} attempts \"\n",
"        f\"starting from seed {initial_seed}\"\n",
"    )\n",
"\n",
"# ensure that factor matrices are full rank\n",
"for factor in sim_tensor.factors:\n",
"    print(np.linalg.matrix_rank(factor))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
" \n",
" "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.plotly.v1+json": {
"config": {
"plotlyServerURL": "https://plot.ly"
},
"data": [
{
"hovertemplate": "axis0=%{x}
axis1=%{y}
axis2=%{z}
abs_exp=%{marker.size}
abundance=%{marker.color}