2. Set up training config
pseudodynamics+ provides two ways of passing arguments when running with a script.
One is to pass each argument on the command line, e.g. `-a` or `--arg`. The other is to save the arguments in a config file that sets up the training. This makes it easier to track different models and improves reproducibility.
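The two ways can be combined. Below is a minimal sketch of the general pattern using Python's argparse. It is not the pseudodynamics+ implementation. Values from a JSON config file override defaults, and flags given on the command line override both. The sketch assumes the file uses the {"raw_args": {...}} layout built later in this notebook. The function `resolve_args` and the `--lr`/`--batch_size` flags are hypothetical; check main_train.py for the flags it actually accepts.

```python
import argparse
import json


def resolve_args(argv, defaults):
    """Merge defaults < JSON config file < command-line flags (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    # Hypothetical flags for illustration only
    parser.add_argument("--config", default=None)
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    cli = vars(parser.parse_args(argv))

    args = dict(defaults)
    if cli["config"] is not None:
        with open(cli["config"]) as f:
            # assumes the {"raw_args": {...}} layout used in this notebook
            args.update(json.load(f)["raw_args"])
    # only flags actually given on the command line override the file
    args.update({k: v for k, v in cli.items() if v is not None})
    return args
```

For example, resolve_args(["--config", "logs/testing_config.json", "--lr", "0.001"], {}) would take lr from the command line and everything else from the file.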
In this notebook, we will:
- go through the different arguments
- create a config JSON file
- show how to generate an `ExperimentConfig` object from a config file
%load_ext autoreload
%autoreload 2
import os, sys
# On macOS, uncomment the lines below to work around duplicate OpenMP runtime errors:
# if sys.platform.startswith("darwin"):
#     os.environ['KMP_DUPLICATE_LIB_OK']='True'
import json
import pseudodynamics as pdp
os.chdir(pdp.main_dir)  # work from the project main directory so relative paths such as data/ and logs/ resolve
Basic args
basic_config = {
    "config": None,
    "dataset": "tom_pos", # dataset prefix of the h5ad file, i.e. "data/tom_pos.h5ad"
    "log_name": "tom_pos_fulltime", # name of the logging directory (logs/`log_name`)
    "progress_bar": True
}
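The comments on `dataset` and `log_name` map these two args to an input file and a log directory relative to the project main directory. Below is a minimal sketch of that mapping. `resolve_paths` is a hypothetical helper, and it assumes only the data/ and logs/ conventions stated in those comments.

```python
import os


def resolve_paths(config, data_dir="data", log_root="logs"):
    """Map the basic args to file locations (hypothetical helper).

    "dataset" is the h5ad prefix under data/; "log_name" is a directory under logs/.
    """
    h5ad_path = os.path.join(data_dir, config["dataset"] + ".h5ad")
    log_dir = os.path.join(log_root, config["log_name"])
    return h5ad_path, log_dir
```

On POSIX systems, resolve_paths(basic_config) returns ('data/tom_pos.h5ad', 'logs/tom_pos_fulltime').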
model_config = {
    # Model choice
    "model" : "pde_params", # pseudodynamics model; `log_pde_params` is another option
    "time_sensitive": True, # True: parameters depend on time and state. False: parameters are time-independent
    # Neural network
    "n_dimension": 10, # input size and the dimension used for density estimation
    "channels": "64,64", # hidden layer sizes, e.g. "64,64" gives a 4-layer density network [input, 64, 64, output]
    "lr": 0.0003, # learning rate
    "schedule_lr": "CyclicLR", # learning rate scheduler
    # Dynamics equation precision
    "tol": 0.0001, # tolerance for the NeuralODE integration (atol = rtol = tol)
    "time_scale_factor" : 1, # scales the integration time of the NeuralODE; smaller factor -> longer integration
    "pretrained": None, # pretrained checkpoint to resume from
    "gpu_devices": 0, # which GPU to use; set to None for CPU training
    # Loss-term weights
    "weight_intensity": None, # weight emphasizing high-density cells: > 1 up-weights, < 1 down-weights
    "deltax_weight": 0.01, # weight of the term informing v with local state transitions, i.e. the similarity between deltax and v
    "R_weight": 1, # weight balancing the PDE residual loss against the data-related loss
    "growth_weight": None, # weight regularizing the contribution of growth to the overall density gain; larger means a stricter bound
    "D_penalty" : 1, # strength of the penalty restricting diffusion
}
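The `channels` string lists only the hidden layers; the input and output sizes come from the model. Below is a minimal sketch of the expansion described in the comment, using `layer_sizes` as a hypothetical helper. The training record later in this notebook shows v_channels = [11, 64, 64, 10]. That is consistent with an input of n_dimension + 1 and an output of n_dimension. The extra input is presumably time, since time_sensitive is True, but that reading is our assumption.

```python
def layer_sizes(channels, n_in, n_out):
    """Expand a "channels" string into a full layer-size list (hypothetical helper).

    "64,64" -> [n_in, 64, 64, n_out], i.e. a 4-layer network.
    """
    # int() tolerates surrounding whitespace, so "64, 64" also works
    hidden = [int(c) for c in channels.split(",") if c.strip()]
    return [n_in, *hidden, n_out]
```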
Another important group of args configures the dataset:
dataset_config = {
    "cellstate_key": "DM_scaled", # obsm key used as cell state
    "deltax_key": "Delta_DM", # obsm key used for local cell state changes
    "timepoint_idx": [0, 1, 2, 3, 4, 6, 8], # indices of the timepoints to use
    "knn_volume": False,
    "batch_size": 50,
    "bw": None,
    "norm_time": False,
}
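`timepoint_idx` refers to timepoints by numeric index, so the list above skips indices 5 and 7. Below is a minimal sketch of index-based selection with some sanity checks. `select_timepoints` is hypothetical. Requiring indices to be unique and increasing is an assumption for illustration, not documented pseudodynamics+ behavior.

```python
def select_timepoints(all_timepoints, timepoint_idx):
    """Pick timepoints by numeric index (hypothetical helper with assumed checks)."""
    idx = list(timepoint_idx)
    if len(set(idx)) != len(idx):
        raise ValueError("timepoint_idx contains duplicates")
    if idx != sorted(idx):
        raise ValueError("timepoint_idx should be in increasing order")
    if idx and (idx[0] < 0 or idx[-1] >= len(all_timepoints)):
        raise IndexError("timepoint_idx out of range")
    return [all_timepoints[i] for i in idx]
```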
raw_args = {}
raw_args.update(basic_config)
raw_args.update(model_config)
raw_args.update(dataset_config)
configs = {"raw_args": raw_args}
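`dict.update` silently overwrites. If two of the dicts above shared a key, the later one would win without warning. Below is a small alternative for anyone who prefers to fail loudly on duplicate keys. `merge_configs` is a hypothetical helper, not part of pseudodynamics+.

```python
def merge_configs(*configs):
    """Merge config dicts left to right, raising KeyError on duplicate keys."""
    merged = {}
    for cfg in configs:
        overlap = merged.keys() & cfg.keys()  # keys views support set operations
        if overlap:
            raise KeyError(f"duplicate config keys: {sorted(overlap)}")
        merged.update(cfg)
    return merged
```

The keys of the three dicts in this notebook do not overlap. So configs = {"raw_args": merge_configs(basic_config, model_config, dataset_config)} gives the same result as the update() calls.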
os.makedirs('logs', exist_ok=True)  # make sure the log directory exists
with open('logs/testing_config.json', 'w') as f:
    json.dump(configs, f, indent=4)
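JSON has no tuples and allows only string keys, so some Python values change on a save/load round trip. Tuples come back as lists and integer keys come back as strings. `None` is stored as `null` and read back as `None`. Below is a quick check to run before writing; `json_roundtrip_ok` is a hypothetical helper.

```python
import json


def json_roundtrip_ok(configs):
    """Return True if `configs` survives json.dumps/json.loads unchanged."""
    return json.loads(json.dumps(configs)) == configs
```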
Instantiate an ExperimentConfig object
test_config = pdp.ExperimentConfig(config='testing_config.json')
test_config._get_model_config()
{'model_class': 'pde_params',
'channels': '64,64',
'activation_fn': None,
'ode_tol': 0.0001,
'growth_weight': None,
'R_weight': 1,
'D_penalty': 1,
'deltax_weight': 0.01,
'weight_intensity': None,
'time_scale_factor': 1,
'time_sensitive': True,
'v_channels': None,
'g_channels': None,
'D_channels': None}
Training
We recommend training the model with the script. From the project main directory, run the following command:
python main_train.py --config logs/testing_config.json -G None
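The same command can also be launched from the notebook. The sketch below builds the argument list with `build_train_cmd`, a hypothetical helper that reproduces only the flags in the command above. Treating the value after -G as the GPU device setting (None for CPU) is an assumption based on the gpu_devices arg. Using sys.executable runs the script with the notebook's own interpreter.

```python
import sys


def build_train_cmd(config_path, gpu="None"):
    """Build the argv for the training command shown above (sketch)."""
    # -G None mirrors the example command; the value is assumed to select the GPU device
    return [sys.executable, "main_train.py", "--config", config_path, "-G", str(gpu)]


# To launch it (from the project main directory):
# import subprocess; subprocess.run(build_train_cmd("logs/testing_config.json"), check=True)
```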
An experiment record is generated automatically under the log_name directory. For example, the command above writes a record to logs/tom_pos_fulltime/pde_params_tsense/V0_config.json. The record file can be used to restore the model and the dataset.
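Records are named by version, as in V0_config.json. Only V0 appears in this notebook, so the sketch assumes later runs in the same directory are saved as V1_config.json, V2_config.json, and so on. Under that assumption, a hypothetical helper can pick the most recent record. The version is compared numerically, so V10 ranks after V9.

```python
import os
import re

# assumed naming scheme: V<version>_config.json
_RECORD = re.compile(r"V(\d+)_config\.json")


def latest_record(run_dir):
    """Return the path of the highest-numbered V<n>_config.json in run_dir, or None."""
    best, best_n = None, -1
    for name in os.listdir(run_dir):
        m = _RECORD.fullmatch(name)
        if m and int(m.group(1)) > best_n:
            best, best_n = os.path.join(run_dir, name), int(m.group(1))
    return best
```

For example, latest_record('logs/tom_pos_fulltime/pde_params_tsense') would return the V0 record used below if it is the only one.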
# load record json
v0_config = pdp.ExperimentConfig(config='logs/tom_pos_fulltime/pde_params_tsense/V0_config.json')
# check updated model config
v0_config.model_config
{'model_class': 'pde_params',
'channels': None,
'activation_fn': None,
'ode_tol': 0.0001,
'growth_weight': 0,
'R_weight': 1,
'D_penalty': 1,
'deltax_weight': 0.01,
'weight_intensity': 1,
'time_scale_factor': 1,
'time_sensitive': True,
'v_channels': [11, 64, 64, 10],
'g_channels': [11, 64, 64, 1],
'D_channels': [11, 64, 64, 1]}
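Comparing the two model configs shows what the training script filled in. `channels` was replaced by the per-network `v_channels`, `g_channels` and `D_channels` lists. `growth_weight` and `weight_intensity` changed from None to 0 and 1. `diff_configs` is a small, hypothetical helper for such comparisons.

```python
def diff_configs(before, after):
    """Return {key: (before, after)} for keys whose values differ.

    A key missing from one dict is reported as None on that side.
    """
    keys = before.keys() | after.keys()
    return {
        k: (before.get(k), after.get(k))
        for k in sorted(keys)
        if before.get(k) != after.get(k)
    }
```

Usage: diff_configs(test_config._get_model_config(), v0_config.model_config).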
# we can locate the checkpoint by:
v0_config.find_lastest_ckpt()