# This code is part of a Qiskit project.
#
# (C) Copyright IBM 2019, 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
""" Pseudo-randomly generated mock stock-market data provider """
from typing import Optional, Union, List
import datetime
import logging
import pandas as pd
import numpy as np
from ._base_data_provider import BaseDataProvider
logger = logging.getLogger(__name__)
class RandomDataProvider(BaseDataProvider):
    """Mock stock-market data provider driven by a pseudo-random generator.

    For every ticker, :meth:`run` produces one synthetic series: a cumulative
    sum of standard-normal steps shifted by a random integer offset, floored
    at zero. Once a series touches zero it remains zero until its end.
    """

    def __init__(
        self,
        tickers: Optional[Union[str, List[str]]] = None,
        start: datetime.datetime = datetime.datetime(2016, 1, 1),
        end: datetime.datetime = datetime.datetime(2016, 1, 30),
        seed: Optional[int] = None,
    ) -> None:
        """
        Args:
            tickers: ticker names, given either as a list or as a single string
                in which the names are separated by ``;`` or by newlines.
                ``None`` selects ``["TICKER1", "TICKER2"]``.
            start: first data point
            end: last data point precedes this date
            seed: optional seed for the random number generator
        """
        super().__init__()
        if tickers is None:
            tickers = ["TICKER1", "TICKER2"]
        if not isinstance(tickers, list):
            # A plain string: treat newlines like ';' and split into names.
            tickers = tickers.replace("\n", ";").split(";")
        self._tickers = tickers
        self._n = len(self._tickers)
        self._start = start
        self._end = end
        self._seed = seed

    def run(self) -> None:
        """
        Generates one pseudo-random series per ticker and stores them, as a
        list of lists, in ``self._data``, thus enabling the
        get_similarity_matrix and get_covariance_matrix methods in the base
        class.

        Each series has ``(end - start).days`` entries.
        """
        num_points = (self._end - self._start).days
        rng = np.random.default_rng(self._seed)
        self._data = []
        for _ in self._tickers:
            # Draw the steps before the offset: the order of the draws
            # determines the output for a given seed.
            steps = rng.standard_normal(num_points)
            offset = rng.integers(1, 101)
            walk = pd.DataFrame(steps).cumsum() + offset
            raw = walk[0].values
            # Floor the series at zero.
            series = np.maximum(raw, np.zeros(len(raw))).tolist()
            # Locate the first entry that hit zero, if any ...
            first_zero = next(
                (pos for pos, value in enumerate(series) if value == 0), None
            )
            if first_zero is not None:
                # ... and keep the series at 0 from that point onwards.
                series[first_zero:] = [0] * (len(series) - first_zero)
            self._data.append(series)