from dataset import Audio_Dataset
from transforms import Scattering
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin,MIRPlugin,EarlyStoppingPlugin,GenerativeReplayPlugin,LwFPlugin,ReplayPlugin,LRSchedulerPlugin,SynapticIntelligencePlugin,GDumbPlugin,CoPEPlugin,AGEMPlugin
from avalanche.benchmarks.generators import nc_benchmark
from avalanche.evaluation.metrics import forgetting_metrics, accuracy_metrics,loss_metrics, timing_metrics, cpu_usage_metrics, confusion_matrix_metrics, gpu_usage_metrics
from avalanche.training.templates import SupervisedTemplate
from avalanche.training import JointTraining,Naive,OnlineNaive,ICaRL,ClassBalancedBuffer,icarl,MIR
from avalanche.training.supervised import StreamingLDA
from avalanche.benchmarks.scenarios import OnlineCLScenario
from plugins.ekfac_plugin import EKFAC_Plugin,KFAC_Plugin
from nemo.core.optim.optimizers import Novograd
from torch.nn import CrossEntropyLoss,NLLLoss
from torch.optim import SGD,Adam
import torch
import models
from matchbox.ConvASRDecoder import ConvASRDecoderClassification
from sacred import Experiment
from sacred.observers import MongoObserver
from tensorflow.core.util import event_pb2 # Used to interface with
import tensorflow as tf # tensorboard logs
import time
import os
import glob
import logging
import numpy as np
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
from templates.river import RiverTemplate
from river.forest import ARFClassifier
from s_trees.utils.cls_utils import HTtoRIVER
from s_trees.learners.ht import HoeffdingTree
from s_trees.learners.irf import IncrementalRandomForest
from templates.qda import StreamingQDA
from templates.tricks import Supervised_LB_Template,Supervised_SS_Template
# Sacred experiment object shared by the config scope and the main function below.
# Runs are recorded through a MongoDB observer writing to the 'Continual-learning' database.
ex = Experiment(name='Online EKFAC balanced BIG Replay with SS with sgd')
ex.observers.append(
    MongoObserver(db_name='Continual-learning')
)
@ex.config
def cfg():
    """Config function for efficient config saving in sacred.

    Every local variable assigned in this scope becomes a Sacred config entry;
    entries whose names match parameters of the captured ``run`` function are
    injected into it by name.
    """
    device = torch.device("cuda")  # device handed to the continual-learning strategy
    # Optimizer selector. NOTE(review): run() only matches 'sgd', 'adam' and
    # 'novograd', so 'novo' falls through to the Adam fallback (with a warning)
    # -- confirm whether Novograd was intended.
    opt_type='novo'
    learning_rate=0.1
    train_batch_size=128
    eval_batch_size=128
    train_epochs=1  # training epochs per experience
    momentum=0.9  # only used by the SGD branch in run()
    w_decay=0.001  # weight decay passed to every optimizer in run()
    betas=[0.95, 0.5]  # used by the Adam / Novograd branches in run()
    seed=2  # Sacred's run seed (reaches run() as _seed)
    PolynomialHoldDecayAnnealing_schedule=False  # NOTE(review): not read by run(); scheduler is toggled in the plugin list
    tags = []#"Regularization","MatchboxNet","M5","Joint","Naive","Replay","Combined","Architectural"]#to choose from in Omniboard
    save_model=True  # upload the final model state_dict as a Sacred artifact
@ex.automain
def run(device, opt_type, learning_rate, train_batch_size, eval_batch_size,
        train_epochs, momentum, w_decay, betas, save_model, _seed, _run):
    """Main function of the continual learning framework.

    This uses `Avalanche lib <https://avalanche.continualai.org/>`_ to create
    continual learning scenarios, and `Sacred <https://github.com/IDSIA/sacred.git>`_
    to store the experiments.

    .. note::
        The Avalanche library uses multiple main concepts to enable continual learning with pytorch:

        - **Strategies :** Strategies model the pytorch training loop. One can thus create strategies
          for special loop cycles and algorithms.
        - **Scenarios :** A particular setting, i.e. specificities about the continual stream of data,
          a continual learning algorithm will face. For example we can have class incremental
          scenarios or task incremental scenarios.
        - **Plugins :** A module designed to simply augment a regular continual strategy with custom
          behavior. Adding evaluators for example or enabling replay learning.

        For more detailed information about the use of this library check out their main
        `website <https://avalanche.continualai.org/>`_ and their
        `API <https://avalanche-api.continualai.org/en/v0.3.1/>`_

    Args:
        device (torch.device): Device the strategy trains and evaluates on.
        opt_type (str): Optimizer type: ``'sgd'``, ``'adam'`` or ``'novograd'``.
            Any other value logs a warning and falls back to Adam.
        learning_rate (float): Learning rate.
        train_batch_size (int): Train mini batch size.
        eval_batch_size (int): Eval mini batch size.
        train_epochs (int): Number of training epochs on each experience.
        momentum (float): Momentum value (SGD only).
        w_decay (float): Weight decay, passed to every optimizer.
        betas (list[float]): Beta coefficients (Adam / Novograd only).
        save_model (bool): Save model as artifact or not.
        _seed (int): Random seed generated by the sacred experiment. This seed is common to all
            the used libraries capable of randomness.
        _run: Sacred runtime environment.

    Returns:
        float: Top-1 stream accuracy on the test stream, taken from the last evaluation.
    """
    # Available models (check the pre-processing before switching):
    #   Scattering : torch.nn.Sequential(Scattering(), models.EncDecBaseModel(num_mels=50, num_classes=35, final_filter=128, input_length=1000))
    #   MFCC       : models.EncDecBaseModel(num_mels=64, num_classes=35, final_filter=128, input_length=1601)
    #   Basic net  : M5(n_input=1, n_channel=35)
    # Load a pretrained Scattering + EncDecBaseModel network (70 output classes) and freeze it.
    pretrained = torch.nn.Sequential(
        Scattering(),
        models.EncDecBaseModel(num_mels=50, num_classes=70, final_filter=128, input_length=1000),
    )
    pretrained.load_state_dict(torch.load('./pre_training/models/pretrained2.pt'))
    for param in pretrained.parameters():
        param.requires_grad = False
    # Reuse the frozen front-end and encoder; only the newly created layers
    # (pooling + 35-way linear head) can receive gradients.
    model = torch.nn.Sequential(
        pretrained[0],
        pretrained[1].encoder,
        models.Pool(128),
        torch.nn.Linear(128, 35),
    )

    # Import dataset
    dataset = Audio_Dataset(train_transformation=None, test_transformation=None)
    command_train = dataset(train=True, pre_process=False, output_shape=[128])
    command_test = dataset(train=False, pre_process=False, output_shape=[128])
    # Alternative dataset:
    # command_train = dataset.MLCommons(sub_folder="subset2", subset='training')
    # command_test = dataset.MLCommons(sub_folder="subset2", subset='testing')

    # Class-incremental scenario: 7 experiences, class order shuffled with the Sacred seed.
    scenario = nc_benchmark(command_train, command_test, n_experiences=7, shuffle=True,
                            seed=_seed, task_labels=False)

    # Logging: Tensorboard, a text file (handle kept so it can be closed before upload) and stdout.
    tb_logger = TensorboardLogger(tb_log_dir='../tb_data')
    log_file = open('../log.txt', 'a')
    text_logger = TextLogger(log_file)
    interactive_logger = InteractiveLogger()
    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        timing_metrics(epoch=True, epoch_running=True),
        forgetting_metrics(experience=True, stream=True),
        cpu_usage_metrics(experience=True),
        confusion_matrix_metrics(num_classes=scenario.n_classes, save_image=False, stream=True),
        gpu_usage_metrics(gpu_id=0, minibatch=True, epoch=True, experience=True, stream=True),
        loggers=[interactive_logger, text_logger, tb_logger],
    )

    # Initialize the optimizer.
    # NOTE(review): the default config value 'novo' matches none of these branches
    # and therefore selects the Adam fallback -- confirm whether 'novograd' was intended.
    if opt_type == 'sgd':
        optimizer = SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=w_decay)
    elif opt_type == 'adam':
        optimizer = Adam(model.parameters(), lr=learning_rate, betas=betas, weight_decay=w_decay)
    elif opt_type == 'novograd':
        optimizer = Novograd(model.parameters(), lr=learning_rate, betas=betas, weight_decay=w_decay)
    else:
        logging.warning("This type of optimizer is not implemented, defaulting to ADAM")
        optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=w_decay, betas=betas)

    # Plugin examples (several plugins can be combined on the same strategy):
    #   LR scheduler  : LRSchedulerPlugin(scheduler=scheduler, step_granularity="iteration")
    #   Replay        : ReplayPlugin(mem_size=50)
    #   Regularization: SynapticIntelligencePlugin
    #   Pseudo replay : GenerativeReplayPlugin()
    # Hold-then-polynomial-decay schedule: constant LR for the first 40% of the
    # iterations, then power-2 decay over the remaining 60%.
    # Batches per epoch = ceil(len / batch_size) (the DataLoader keeps the last partial batch).
    steps_per_epoch = (len(command_train) + train_batch_size - 1) // train_batch_size
    total_steps = steps_per_epoch * train_epochs
    hold_steps = int(total_steps * 0.4)
    decay_steps = int(total_steps * 0.6)
    # Built but not attached to the strategy; uncomment LRSchedulerPlugin below to use it.
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer=optimizer,
        schedulers=[
            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, total_iters=hold_steps),
            torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=decay_steps, power=2.0),
        ],
        milestones=[hold_steps],
    )
    buffer_policy = ClassBalancedBuffer(max_size=20000, adaptive_size=True)
    plugin_list = [
        ReplayPlugin(20000, storage_policy=buffer_policy),
        EKFAC_Plugin(network=model),
        # LRSchedulerPlugin(scheduler=scheduler, step_granularity="iteration"),
    ]

    # Strategy examples:
    #   Joint training      : JointTraining(model, optimizer, CrossEntropyLoss(), train_mb_size=train_batch_size, eval_mb_size=eval_batch_size, device=device, train_epochs=train_epochs, evaluator=eval_plugin, plugins=plugin_list)
    #   Naive               : Naive()
    #   Supervised Template : SupervisedTemplate(model, optimizer, CrossEntropyLoss(), train_mb_size=train_batch_size, eval_mb_size=eval_batch_size, device=device, train_epochs=train_epochs, evaluator=eval_plugin, plugins=plugin_list)
    #   Streaming LDA       : StreamingLDA(slda_model=model, criterion=CrossEntropyLoss(), input_size=128, num_classes=35, evaluator=eval_plugin, plugins=plugin_list, device=device, eval_mb_size=eval_batch_size, train_mb_size=train_batch_size)
    #   Online Naive        : OnlineNaive(model=model, optimizer=optimizer, criterion=CrossEntropyLoss(), train_mb_size=train_batch_size, eval_mb_size=eval_batch_size, device=device, plugins=plugin_list, evaluator=eval_plugin, eval_every=-1)
    #   River Template      : RiverTemplate(deep_model=model, online_model=HTtoRIVER(classifier=IncrementalRandomForest(size=10, num_workers=15, att_split_est=True)), criterion=None, input_size=128, train_mb_size=train_batch_size, eval_mb_size=eval_batch_size, device=device, evaluator=EvaluationPlugin(loggers=[interactive_logger, text_logger, tb_logger]))
    cl_strategy = Supervised_SS_Template(
        model, optimizer, CrossEntropyLoss(),
        train_mb_size=train_batch_size, eval_mb_size=eval_batch_size,
        device=device, train_epochs=train_epochs,
        evaluator=eval_plugin, plugins=plugin_list,
    )

    # Training loop
    logging.info('Starting experiment...')
    results = []
    if isinstance(cl_strategy, JointTraining):
        # Joint training consumes the whole train stream at once.
        logging.info("Start of joint training: ")
        cl_strategy.train(scenario.train_stream)
        logging.info('Training completed')
        logging.info('Start of Eval')
        results.append(cl_strategy.eval(scenario.test_stream))
        logging.info('End of Eval')
    else:
        for experience in scenario.train_stream:
            logging.info("Start of experience: %s", experience.current_experience)
            logging.info("Current Classes: %s", experience.classes_in_this_experience)
            cl_strategy.train(experience)
            logging.info('Training completed')
            # Per-experience evaluation. Use these lines instead of the final
            # evaluation below when NOT using the separated softmax trick:
            # logging.info('Start of Eval')
            # results.append(cl_strategy.eval(scenario.test_stream))
            # logging.info('End of Eval')
        # Separated-softmax variant: evaluate the test stream once, after the whole train stream.
        # NOTE(review): placement after the loop is inferred from the TODO above -- confirm.
        logging.info('Start of Eval')
        results.append(cl_strategy.eval(scenario.test_stream))
        logging.info('End of Eval')

    # Logging metrics and artifacts into sacred: row-normalised confusion-matrix heatmap.
    cf_matrix = results[-1]['ConfusionMatrix_Stream/eval_phase/test_stream']
    counts = cf_matrix.cpu().numpy().astype(float)
    labels = list(dataset.labels_names)
    df_cm = pd.DataFrame(counts / counts.sum(axis=1, keepdims=True), index=labels, columns=labels)
    fig, ax = plt.subplots(figsize=(24, 14))
    sn.heatmap(df_cm, annot=True, ax=ax)
    fig.savefig('heatmap.png')
    plt.close(fig)  # release the figure; nothing else draws on it
    _run.add_artifact('./heatmap.png')
    os.remove('./heatmap.png')

    # Save the model as an artifact
    if save_model:
        model_path = os.path.join('../models/', 'model.pt')
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)
        _run.add_artifact(model_path)
        os.remove(model_path)

    # The save process in tensorboard is multithreaded; this sleep gives it time
    # to write the event file before it is read back.
    time.sleep(120)
    # Getting the latest file added to ../tb_data
    list_of_files = glob.glob('../tb_data/*')
    latest_file = max(list_of_files, key=os.path.getctime)
    # Saving the tensorboard data in sacred
    _run.add_artifact(latest_file)  # raw tf record file
    # Close the text log so its contents are complete before upload and the
    # handle is not leaked when the file is removed.
    log_file.close()
    _run.add_artifact('../log.txt')
    os.remove('../log.txt')
    # Replay every scalar from the tensorboard event file into Sacred's metrics.
    serialized_examples = tf.data.TFRecordDataset(latest_file)
    for serialized_example in serialized_examples:
        event = event_pb2.Event.FromString(serialized_example.numpy())
        for value in event.summary.value:
            _run.log_scalar(value.tag, value.simple_value)
    # We give the stream accuracy as the result. The other metrics can be found in Omniboard
    return results[-1]['Top1_Acc_Stream/eval_phase/test_stream/Task000']