この記事は、参考文献のコードから不使用のコードを削除し、行間を詰めて全体を眺めやすく整理しただけで、内容は参考文献のコードと同じです。

目次
Traffic データのロードと862変量時系列化
from datasets import load_dataset
import pandas as pd
import numpy as np
from functools import lru_cache
from functools import partial
from gluonts.dataset.multivariate_grouper import MultivariateGrouper


@lru_cache(10_000)
def convert_to_pandas_period(date, freq):
    """Convert a raw start date to a pandas Period (cached, since the same
    start dates repeat across many series)."""
    return pd.Period(date, freq)


def transform_start_field(batch, freq):
    """Replace every 'start' entry of the batch with a pandas Period."""
    batch['start'] = [convert_to_pandas_period(date, freq) for date in batch['start']]
    return batch


# Load the hourly traffic dataset from the Monash time-series archive.
dataset = load_dataset('monash_tsf', 'traffic_hourly')
train_dataset = dataset['train']
test_dataset = dataset['test']

num_of_variates = len(train_dataset)  # 862
freq = '1H'
prediction_length = 48

# Convert the 'start' field of each example lazily on access.
train_dataset.set_transform(partial(transform_start_field, freq=freq))
test_dataset.set_transform(partial(transform_start_field, freq=freq))

# Group the 862 univariate series into one multivariate series.
train_grouper = MultivariateGrouper(max_target_dim=num_of_variates)
test_grouper = MultivariateGrouper(max_target_dim=num_of_variates, num_test_dates=1)
multi_variate_train_dataset = train_grouper(train_dataset)
multi_variate_test_dataset = test_grouper(test_dataset)
時間的特徴を付加する関数の用意
from transformers import PretrainedConfig
from gluonts.dataset.field_names import FieldName
from gluonts.transform import *


def create_transformation(freq: str, config: PretrainedConfig) -> Transformation:
    """Build the GluonTS preprocessing chain for the model.

    Removes feature fields the config does not use, converts the target and
    static features to numpy arrays, adds an observed-values mask plus
    time/age features, and renames fields to the names the HF time-series
    models expect.
    """
    # Fields to strip when the model config says they are unused.
    unused_fields = []
    if config.num_static_real_features == 0:
        unused_fields.append(FieldName.FEAT_STATIC_REAL)
    if config.num_dynamic_real_features == 0:
        unused_fields.append(FieldName.FEAT_DYNAMIC_REAL)
    if config.num_static_categorical_features == 0:
        unused_fields.append(FieldName.FEAT_STATIC_CAT)

    steps = [RemoveFields(field_names=unused_fields)]

    # Static features are only converted when the model actually uses them.
    if config.num_static_categorical_features > 0:
        steps.append(
            AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int)
        )
    if config.num_static_real_features > 0:
        steps.append(AsNumpyArray(field=FieldName.FEAT_STATIC_REAL, expected_ndim=1))

    steps += [
        # Multivariate targets are 2-D, univariate targets 1-D.
        AsNumpyArray(
            field=FieldName.TARGET,
            expected_ndim=1 if config.input_size == 1 else 2,
        ),
        AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        ),
        AddTimeFeatures(
            start_field=FieldName.START,
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_TIME,
            time_features=time_features_from_frequency_str(freq),
            pred_length=config.prediction_length,
        ),
        AddAgeFeature(
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_AGE,
            pred_length=config.prediction_length,
            log_scale=True,
        ),
        # Stack time, age and (if present) dynamic real features together.
        VstackFeatures(
            output_field=FieldName.FEAT_TIME,
            input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
            + (
                [FieldName.FEAT_DYNAMIC_REAL]
                if config.num_dynamic_real_features > 0
                else []
            ),
        ),
        RenameFields(
            mapping={
                FieldName.FEAT_STATIC_CAT: "static_categorical_features",
                FieldName.FEAT_STATIC_REAL: "static_real_features",
                FieldName.FEAT_TIME: "time_features",
                FieldName.TARGET: "values",
                FieldName.OBSERVED_VALUES: "observed_mask",
            }
        ),
    ]
    return Chain(steps)
ウィンドウをサンプリングする関数の用意
from gluonts.transform.sampler import InstanceSampler
from typing import Optional


def create_instance_splitter(
    config: PretrainedConfig,
    mode: str,
    train_sampler: Optional[InstanceSampler] = None,
    validation_sampler: Optional[InstanceSampler] = None,
) -> Transformation:
    """Return an InstanceSplitter that cuts training/validation/test windows.

    ``mode`` must be "train", "validation" or "test"; any other value
    raises KeyError. Callers may override the train/validation samplers.
    """
    # All three samplers are built up front; the mode just selects one
    # (unknown modes therefore fail with KeyError, as before).
    samplers = {
        "train": train_sampler
        or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=config.prediction_length
        ),
        "validation": validation_sampler
        or ValidationSplitSampler(min_future=config.prediction_length),
        "test": TestSplitSampler(),
    }
    return InstanceSplitter(
        target_field="values",
        is_pad_field=FieldName.IS_PAD,
        start_field=FieldName.START,
        forecast_start_field=FieldName.FORECAST_START,
        instance_sampler=samplers[mode],
        # Past window also covers the largest configured lag.
        past_length=config.context_length + max(config.lags_sequence),
        future_length=config.prediction_length,
        time_series_fields=["time_features", "observed_mask"],
    )
データローダーの用意
import torch
from typing import Iterable
from gluonts.itertools import Cached, Cyclic
from gluonts.dataset.loader import as_stacked_batches


def create_train_dataloader(
    config: PretrainedConfig,
    freq,
    data,
    batch_size: int,
    num_batches_per_epoch: int,
    shuffle_buffer_length: Optional[int] = None,
    cache_data: bool = True,
    **kwargs,
) -> Iterable:
    """Build an endlessly cycling, batched training dataloader over ``data``.

    Applies the preprocessing chain and the "train" instance splitter, then
    stacks sampled windows into tensors of ``batch_size`` examples,
    yielding ``num_batches_per_epoch`` batches per epoch.
    """
    # Tensor names the model consumes at prediction time.
    input_names = [
        "past_time_features",
        "past_values",
        "past_observed_mask",
        "future_time_features",
    ]
    if config.num_static_categorical_features > 0:
        input_names.append("static_categorical_features")
    if config.num_static_real_features > 0:
        input_names.append("static_real_features")
    # Training additionally needs the ground-truth future window.
    training_input_names = input_names + ["future_values", "future_observed_mask"]

    prepared = create_transformation(freq, config).apply(data, is_train=True)
    if cache_data:
        # Avoid re-running the transformation on every pass over the data.
        prepared = Cached(prepared)

    splitter = create_instance_splitter(config, "train")
    # Cycle the dataset forever; the batcher caps each epoch below.
    windows = splitter.apply(Cyclic(prepared).stream(), is_train=True)

    return as_stacked_batches(
        windows,
        batch_size=batch_size,
        shuffle_buffer_length=shuffle_buffer_length,
        field_names=training_input_names,
        output_type=torch.tensor,
        num_batches_per_epoch=num_batches_per_epoch,
    )
モデルの生成と1バッチへの動作確認
from transformers import InformerConfig, InformerForPrediction
from gluonts.time_feature import time_features_from_frequency_str

time_features = time_features_from_frequency_str(freq)

config = InformerConfig(
    # One input channel per series -> multivariate forecasting.
    input_size=num_of_variates,
    prediction_length=prediction_length,
    # Look back twice the forecast horizon.
    context_length=prediction_length * 2,
    # Lagged inputs: 1 step back and 24*7 steps back (one week at 1H freq).
    lags_sequence=[1, 24 * 7],
    # Frequency-derived features plus the age feature from the transformation.
    num_time_features=len(time_features) + 1,
    dropout=0.1,
    encoder_layers=6,
    decoder_layers=4,
    d_model=64,
)
model = InformerForPrediction(config)

train_dataloader = create_train_dataloader(
    config=config,
    freq=freq,
    data=multi_variate_train_dataset,
    batch_size=256,
    num_batches_per_epoch=100,
)

# Smoke test: push a single batch through the model and report the loss.
batch = next(iter(train_dataloader))
outputs = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=(
        batch["static_categorical_features"]
        if config.num_static_categorical_features > 0
        else None
    ),
    static_real_features=(
        batch["static_real_features"]
        if config.num_static_real_features > 0
        else None
    ),
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
    future_observed_mask=batch["future_observed_mask"],
    output_hidden_states=True,
)
print("Loss:", outputs.loss.item())