Source code for templates.river

import warnings
from typing import Optional, Sequence

import os
import torch

from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.plugins.evaluation import default_evaluator
from avalanche.models.dynamic_modules import MultiTaskModule
from avalanche.models import FeatureExtractorBackbone

from river import metrics


class RiverTemplate(SupervisedTemplate):
    """Avalanche strategy that trains a River online model on deep features.

    A PyTorch model, kept in eval mode and registered without an optimizer,
    maps every mini-batch to an output tensor. Each row of that tensor is
    converted to a ``{"Feature k": value}`` dictionary and handed, one sample
    at a time, to ``online_model`` through ``predict_one`` / ``learn_one``.

    Predictions are accumulated in River metrics (a confusion matrix and a
    classification report), kept separately for training and evaluation and
    printed at the end of each training experience / evaluation run.
    """

    # NOTE(review): ``default_evaluator()`` is evaluated once, when the class
    # body is executed, so every instance relying on the default shares the
    # same evaluator object -- confirm this is intended.
    def __init__(
        self,
        deep_model,
        online_model,
        criterion,
        input_size,
        output_layer_name=None,
        train_epochs: int = 1,
        train_mb_size: int = 1,
        eval_mb_size: int = 1,
        device="cpu",
        plugins: Optional[Sequence["SupervisedPlugin"]] = None,
        evaluator=default_evaluator(),
        eval_every=-1,
    ):
        """Build the strategy.

        :param deep_model: PyTorch module used as a frozen feature extractor.
        :param online_model: object exposing ``predict_one(x)`` and
            ``learn_one(x, y)``.
        :param criterion: forwarded unchanged to ``SupervisedTemplate``.
        :param input_size: number of feature names generated by
            :meth:`tensor_to_dict`.
        :param output_layer_name: when not ``None``, ``deep_model`` is wrapped
            in a ``FeatureExtractorBackbone`` built with this layer name.
        :param train_epochs: forwarded to ``SupervisedTemplate``.
        :param train_mb_size: forwarded to ``SupervisedTemplate``.
        :param eval_mb_size: forwarded to ``SupervisedTemplate``.
        :param device: forwarded to ``SupervisedTemplate``; also the device
            the model is moved to before being wrapped.
        :param plugins: forwarded to ``SupervisedTemplate``; ``None`` is
            replaced by an empty list.
        :param evaluator: forwarded to ``SupervisedTemplate``.
        :param eval_every: forwarded to ``SupervisedTemplate``.
        """
        if plugins is None:
            plugins = []

        deep_model = deep_model.eval()
        if output_layer_name is not None:
            deep_model = FeatureExtractorBackbone(
                deep_model.to(device), output_layer_name
            ).eval()

        super().__init__(
            model=deep_model,
            optimizer=None,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
        )

        self.input_size = input_size
        self.online_model = online_model
        # Separate accumulators so training and evaluation never mix.
        self.train_metrics_list = [
            metrics.ConfusionMatrix(),
            metrics.ClassificationReport(),
        ]
        self.test_metrics_list = [
            metrics.ConfusionMatrix(),
            metrics.ClassificationReport(),
        ]

    def tensor_to_dict(self, tensor):
        """Convert a mini-batch tensor into per-sample Python values.

        :param tensor: tensor to convert.
        :return: for a tensor with more than one dimension, a list with one
            dict per row mapping ``"Feature 0" .. "Feature input_size-1"`` to
            the row values (``zip`` stops at the shorter of the two); for
            any other tensor, the plain ``tolist()`` result.
        """
        values = tensor.cpu().tolist()  # single host transfer for the batch
        if tensor.ndim <= 1:
            return values

        # Feature names are identical for every row: build them once.
        names = [f"Feature {k}" for k in range(self.input_size)]
        return [dict(zip(names, row)) for row in values]

    def forward(self):
        """Compute the model's output given the current mini-batch.

        :return: output of ``self.model`` on ``self.mb_x``.
        :raises NotImplementedError: if the model is a ``MultiTaskModule``.
        """
        self.model.eval()
        if isinstance(self.model, MultiTaskModule):
            raise NotImplementedError(
                "RiverTemplate does not support MultiTaskModule models."
            )
        # The model has no optimizer and its output is only converted to
        # Python values, so recording an autograd graph is wasted work.
        with torch.no_grad():
            return self.model(self.mb_x)

    def _iter_samples(self):
        """Yield ``(features, target)`` pairs for the current mini-batch."""
        return zip(
            self.tensor_to_dict(self.mb_output),
            self.tensor_to_dict(self.mb_y),
        )

    def training_epoch(self, **kwargs):
        """Run one training epoch over ``self.dataloader``.

        The deep model is evaluated on the whole mini-batch; the online
        model is then updated one sample at a time (predict first, then
        learn, so metrics are measured before the sample is learned).

        :param kwargs: forwarded to every template hook.
        :return: None
        """
        for self.mbatch in self.dataloader:
            self._unpack_minibatch()
            self._before_training_iteration(**kwargs)

            self.loss = 0

            # Forward: compute the output on the entire mini-batch.
            self._before_forward(**kwargs)
            self.mb_output = self.forward()
            self._after_forward(**kwargs)

            # Update: process one element at a time.
            self._before_update(**kwargs)
            for x, y in self._iter_samples():
                y_pred = self.online_model.predict_one(x)
                self.online_model.learn_one(x, y)
                # Skip samples for which the model gave no prediction.
                if y_pred is not None:
                    for metric in self.train_metrics_list:
                        metric.update(y, y_pred)
            self._after_update(**kwargs)

            self._after_training_iteration(**kwargs)

    def _after_training_exp(self, **kwargs):
        """Print the training metrics, then run the base-class hook."""
        for metric in self.train_metrics_list:
            print(metric)
        # Delegate so the base-class hook logic still runs.
        super()._after_training_exp(**kwargs)

    def eval_epoch(self, **kwargs):
        """Run one evaluation epoch over ``self.dataloader``.

        The online model is only queried (``predict_one``), never updated.

        :param kwargs: forwarded to every template hook.
        :return: None
        """
        for self.mbatch in self.dataloader:
            self._unpack_minibatch()
            self._before_eval_iteration(**kwargs)

            self._before_eval_forward(**kwargs)
            self.mb_output = self.forward()
            self._after_eval_forward(**kwargs)

            # Process one element at a time.
            for x, y in self._iter_samples():
                y_pred = self.online_model.predict_one(x)
                # Same guard as in training: a missing prediction must not
                # be recorded in the metrics as if it were a class label.
                if y_pred is not None:
                    for metric in self.test_metrics_list:
                        metric.update(y, y_pred)

            self._after_eval_iteration(**kwargs)

    def _after_eval(self, **kwargs):
        """Print the evaluation metrics, then run the base-class hook."""
        for metric in self.test_metrics_list:
            print(metric)
        # Delegate so the base-class hook logic still runs.
        super()._after_eval(**kwargs)

    def make_optimizer(self):
        """Empty function.
        River online models do not need a Pytorch optimizer."""
        pass
# Public names exported by this module.
__all__ = ["RiverTemplate"]