From eebbf16436560f8915db8bff356abd918283d664 Mon Sep 17 00:00:00 2001 From: Ojietohamen Samuel Date: Tue, 8 Oct 2024 16:50:20 +0100 Subject: [PATCH 01/12] feat: add basic evaluation functionality --- examples/sb3_agent_example.py | 28 +++++++++++++++++++++++++++- trade_flow/agents/base.py | 2 ++ trade_flow/agents/sb3_agent.py | 8 ++++---- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/examples/sb3_agent_example.py b/examples/sb3_agent_example.py index 8b79b60..9740293 100644 --- a/examples/sb3_agent_example.py +++ b/examples/sb3_agent_example.py @@ -4,6 +4,8 @@ # import pandas_ta as ta # import trade_flow +from trade_flow.agents.base import Agent +from trade_flow.environments.generic.environment import TradingEnvironment from trade_flow.feed import Stream, DataFeed, NameSpace, Coinbase_BTCUSD_1h, Coinbase_BTCUSD_d from trade_flow.environments.default.oms.exchanges import Exchange from trade_flow.environments.default.oms.execution.simulated import execute_order @@ -101,7 +103,7 @@ def get_env(df: pd.DataFrame = Coinbase_BTCUSD_d): with NameSpace("coinbase"): streams = [ - Stream.source(selected_dataset[c].tolist(), dtype="float").rename(c) + Stream.source(selected_dataset[c].tolist(), dtype=selected_dataset[c].dtype).rename(c) for c in selected_dataset.columns ] @@ -192,6 +194,26 @@ def get_env_with_multiple_renderers(df: pd.DataFrame = Coinbase_BTCUSD_d): return env +def evaluate_model(env: TradingEnvironment, agent: Agent): + """ + Evaluate the model + """ + obs = env.reset() + for i in range(100): + print(obs) + # action, _states = agent.predict(obs[0]) + + # Take a random action + action = env.action_space.sample() + + obs, reward, done, _, info = env.step(action) + env.render() + + if done: + print(f"Episode finished after {i + 1} steps") + break + + if __name__ == "__main__": env = get_env(Coinbase_BTCUSD_1h) # df = Coinbase_BTCUSD_d | Coinbase_BTCUSD_1h @@ -210,6 +232,8 @@ def get_env_with_multiple_renderers(df: pd.DataFrame = Coinbase_BTCUSD_d): performance.plot() + # evaluate_model(env, agent) + print("\n\n---------Environment with Multiple Renderers-------------\n\n") env_multiple = get_env_with_multiple_renderers( @@ -233,3 +257,5 @@ def get_env_with_multiple_renderers(df: pd.DataFrame = Coinbase_BTCUSD_d): performance.plot() performance.net_worth.plot() + + # evaluate_model(env, agent) diff --git a/trade_flow/agents/base.py b/trade_flow/agents/base.py index a39b879..447682b 100644 --- a/trade_flow/agents/base.py +++ b/trade_flow/agents/base.py @@ -81,6 +81,7 @@ def predict(self, state: Any) -> Dict[str, Any]: """ pass + @staticmethod def save(self, checkpoint_path: str) -> None: """ Save the agent's model and state to a file. @@ -96,6 +97,7 @@ def save(self, checkpoint_path: str) -> None: except Exception as e: raise RuntimeError(f"Failed to save the model: {e}") + @staticmethod def load(self, checkpoint_path: str) -> None: """ Load the agent's model and state from a file. diff --git a/trade_flow/agents/sb3_agent.py b/trade_flow/agents/sb3_agent.py index 3ec3548..7f421c8 100644 --- a/trade_flow/agents/sb3_agent.py +++ b/trade_flow/agents/sb3_agent.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, List +from typing import Any, Dict, Optional, List, Tuple import numpy as np from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise @@ -145,7 +145,7 @@ def evaluate(self, env: Any) -> Dict[str, Any]: """ raise NotImplementedError - def predict(self, state: Any) -> Dict[str, Any]: + def predict(self, state: Any, **kwargs) -> Tuple[Any, Any]: """ Predict the next action based on the current state. @@ -162,7 +162,7 @@ def predict(self, state: Any) -> Dict[str, Any]: state = np.array(state) try: - action, _states = self.model.predict(state, deterministic=True) - return {"action": action} + action, _states = self.model.predict(state, deterministic=True, **kwargs) + return action, _states except Exception as e: raise RuntimeError(f"Error during prediction: {str(e)}") From 1a0769ebcb31e69d80937178429a3c2f6231a3dd Mon Sep 17 00:00:00 2001 From: Ojietohamen Samuel Date: Tue, 8 Oct 2024 16:51:03 +0100 Subject: [PATCH 02/12] feat: add environment creation utility --- examples/create_environment_example.py | 201 +++++++++++++++++++++++++ trade_flow/environments/utils.py | 123 +++++++-------- 2 files changed, 259 insertions(+), 65 deletions(-) create mode 100644 examples/create_environment_example.py diff --git a/examples/create_environment_example.py b/examples/create_environment_example.py new file mode 100644 index 0000000..3d081ca --- /dev/null +++ b/examples/create_environment_example.py @@ -0,0 +1,201 @@ +from typing import Optional, Tuple, Union, Dict +import pandas as pd +from sklearn.model_selection import train_test_split +from trade_flow.agents.base import Agent +from trade_flow.environments.generic.environment import TradingEnvironment +from trade_flow.environments.utils import create_env_from_dataframe +from trade_flow.feed import Stream, Coinbase_BTCUSD_1h, Coinbase_BTCUSD_d +from trade_flow.environments.default.oms.exchanges import Exchange +from trade_flow.environments.default.oms.execution.simulated import execute_order +from trade_flow.environments.default.oms.instruments import USD, BTC +from trade_flow.environments.default.oms.wallet import Wallet +from trade_flow.environments.default.oms.portfolio import Portfolio +import trade_flow.environments.default as default +from trade_flow.agents import SB3Agent + + +def encode_symbols(data: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, int]]: + """ + Encodes the currency symbols in the data using label encoding. + + Parameters: + ---------- + data (pd.DataFrame): The DataFrame containing the data. + + Returns: + ------- + Tuple[pd.DataFrame, Dict[str, int]]: The encoded DataFrame and the mapping dictionary. + """ + symbols = data["symbol"].unique() + vocabulary = {pair: i for i, pair in enumerate(symbols)} + data["symbol_encoded"] = data["symbol"].apply(lambda pair: vocabulary[pair]) + return data, vocabulary + + +def create_portfolio(price_history: pd.DataFrame) -> Portfolio: + """ + Creates a default portfolio with initial USD and BTC balance for a Coinbase exchange. + + Parameters: + ---------- + price_history (pd.DataFrame): The DataFrame containing the price history. + + Returns: + ------- + Portfolio: A trading portfolio containing USD and BTC. + """ + coinbase = Exchange("coinbase", service=execute_order)( + Stream.source(price_history["close"].tolist(), dtype=price_history["close"].dtype).rename( + "USD-BTC" + ) # TODO: fix Exception: No stream satisfies selector condition. for `multiple stream sources` + ) + return Portfolio( + USD, + [ + Wallet(coinbase, 1000 * USD), + Wallet(coinbase, 1 * BTC), + ], + ) + + +def create_environment( + df: pd.DataFrame = Coinbase_BTCUSD_d, + split: bool = False, + test_size: float = 0.2, + seed: int = 42, + shuffle: bool = False, +) -> Union[TradingEnvironment, Tuple[TradingEnvironment, TradingEnvironment]]: + """ + Creates a trading environment using the provided dataset and configuration. + + Parameters: + ----------- + df (pd.DataFrame): Input dataset containing market data. + split (bool): Whether to split the dataset into train and test sets. + test_size (float): Proportion of the dataset for testing. + seed (int): Random seed for reproducibility. + shuffle (bool): Whether to shuffle the data before splitting. + + Returns: + ------- + Union[TradingEnvironment, Tuple[TradingEnvironment, TradingEnvironment]]: + Single or tuple of trading environments based on the split parameter. + """ + + dataset = df.reset_index() + + # Preprocess and encode symbols + dataset_encoded, vocabulary = encode_symbols(dataset) + print(f"Vocabulary: {vocabulary}") + + # Create a portfolio and action scheme + portfolio = create_portfolio(dataset_encoded[["close"]]) + action_scheme = default.actions.ManagedRiskOrders() + action_scheme.portfolio = portfolio + + # Create a reward scheme + reward_scheme = default.rewards.RiskAdjustedReturns() + + # Split dataset if required + if split: + train_data, test_data = train_test_split( + dataset_encoded, + test_size=test_size, + random_state=seed, + shuffle=shuffle, + ) + + print(train_data) + + portfolio = create_portfolio(train_data[["close"]]) + action_scheme.portfolio = portfolio + train_env = create_env_from_dataframe( + name="coinbase_train", + dataset=train_data, + action_scheme=action_scheme, + reward_scheme=reward_scheme, + window_size=5, + portfolio=portfolio, + ) + + # portfolio = create_portfolio(test_data[["date", "open", "high", "low", "close", "volume"]]) + # action_scheme.portfolio = portfolio + test_env = create_env_from_dataframe( + name="coinbase_test", + dataset=test_data, + action_scheme=action_scheme, + reward_scheme=reward_scheme, + window_size=5, + portfolio=portfolio, + ) + return train_env, test_env + + # Create a single environment if no split + return create_env_from_dataframe( + name="coinbase_env", + dataset=dataset_encoded[["symbol_encoded", "volume_btc"]], + action_scheme=action_scheme, + reward_scheme=reward_scheme, + window_size=5, + ) + + +def evaluate_model(env: TradingEnvironment, agent: Agent, n_steps: int = 100): + """ + Evaluate the trained model in a given trading environment. + + Args: + env (TradingEnvironment): The trading environment to evaluate. + agent (Agent): The agent to evaluate in the environment. + n_steps (int): Number of steps to run in the evaluation. + """ + obs = env.reset() + for step in range(n_steps): + print(f"Observation at step {step}: {obs}") + + # Take a random action for evaluation purposes (use agent's action for real evaluations) + action = env.action_space.sample() + obs, reward, done, _, _ = env.step(action) + env.render() + + if done: + print(f"Episode finished after {step + 1} steps.") + break + + +def train_and_evaluate( + train_env: TradingEnvironment, + test_env: TradingEnvironment, + n_episodes: int = 2, + n_steps: int = 1000, +): + """ + Train an agent on the training environment and evaluate on the test environment. + + Args: + train_env (TradingEnvironment): Training environment. + test_env (TradingEnvironment): Testing environment. + n_episodes (int): Number of episodes to train the agent. + n_steps (int): Number of steps per episode. + """ + agent = SB3Agent(train_env) + agent.get_model("a2c", {"policy": "MlpPolicy"}) + print(f"Agent: {agent}") + + agent.train(n_episodes=n_episodes, n_steps=n_steps, progress_bar=True) + performance = pd.DataFrame.from_dict( + train_env.action_scheme.portfolio.performance, orient="index" + ) + print("Training performance: \n", performance) + performance.plot() + + print("Evaluating on test environment...") + evaluate_model(test_env, agent) + + +if __name__ == "__main__": + # Create environments for training and testing + train_env, test_env = create_environment(Coinbase_BTCUSD_1h, split=True) + + # Train the agent and evaluate performance + train_and_evaluate(train_env, test_env) diff --git a/trade_flow/environments/utils.py b/trade_flow/environments/utils.py index f2ebd27..dd4f5eb 100644 --- a/trade_flow/environments/utils.py +++ b/trade_flow/environments/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import os import pkgutil import toml @@ -12,17 +12,28 @@ from trade_flow.environments.generic.components.observer import Observer from trade_flow.environments.generic.components.renderer import AggregateRenderer from trade_flow.environments.generic.components.reward_scheme import RewardScheme +from trade_flow.environments.generic.components.stopper import Stopper from trade_flow.environments.generic.environment import TradingEnvironment from trade_flow.feed import Stream, DataFeed, NameSpace def get_available_environments() -> list[tuple]: """ - List available environments in Trade Flow, using environment.toml files for information. + List available environments in Trade Flow using environment metadata files. - Returns: - A list of tuples, where each tuple contains the environment name, version, and description. + This function scans through the modules in the `trade_flow.environments` path, + reads the `metadata.toml` files, and extracts relevant details like environment name, + version, and description based on the expected schema. + + Returns + ------- + list[tuple] + A list of tuples, where each tuple contains: + - `environment_name`: Name of the environment. + - `environment_description`: Short description of the environment. + - `environment_version`: Version of the environment. """ + schema = { "type": "object", "required": ["environment"], @@ -79,49 +90,61 @@ def get_available_environments() -> list[tuple]: def create_env_from_dataframe( name: str, dataset: pd.DataFrame, - portfolio: Component, action_scheme: ActionScheme, reward_scheme: RewardScheme, window_size: int = 1, min_periods: Optional[int] = 100, random_start_pct: float = 0.00, observer: Optional[Observer] = None, + stopper: Optional[Stopper] = None, **kwargs, ) -> TradingEnvironment: - """Creates a `TradingEnvironment` from a dataframe. + """ + Creates a `TradingEnvironment` instance from a given DataFrame. + + This function initializes a trading environment using a provided dataset, portfolio, + action scheme, and other configurable components. Parameters ---------- - name : `str` - The name to be used by the environment. - dataset : `pd.DataFrame` - The dataset to be used by the environment. - portfolio : `Component` - The portfolio component to be used by the environment. - action_scheme : `ActionScheme` - The action scheme for computing actions at every step of an episode. - reward_scheme : `RewardScheme` - The reward scheme for computing rewards at every step of an episode. - window_size : int - The size of the look back window to use for the observation space. - min_periods : int, optional - The minimum number of steps to warm up the `feed`. - random_start_pct : float, optional - Whether to randomize the starting point within the environment at each - observer reset, starting in the first X percentage of the sample - **kwargs : keyword arguments - Extra keyword arguments needed to build the environment. + name : str + The name of the environment instance. + dataset : pd.DataFrame + The input dataset containing historical market data. + action_scheme : ActionScheme + The scheme for possible trading actions. + reward_scheme : RewardScheme + The reward scheme used to compute rewards. + window_size : int, default=1 + The observation window size for the environment. + min_periods : Optional[int], default=100 + Minimum number of observations required before trading. + random_start_pct : float, default=0.00 + Percentage of the dataset for a randomized starting point. + observer : Optional[Observer], default=None + Custom observer, if not provided, the default is used. + stopper : Optional[Stopper], default=None + Custom stopper. If not provided, a `MaxLossStopper` is used with a loss threshold. + **kwargs : Additional keyword arguments for other components. Returns ------- - `TradingEnvironment` + TradingEnvironment + A trading environment configured with the specified dataset, action scheme, + reward scheme, observer, and other components. + + Raises + ------ + AttributeError + If the specified action scheme does not have a `portfolio` attribute. - The default trading environment. """ + # Check for the portfolio in the action scheme + if not hasattr(action_scheme, "portfolio"): + raise AttributeError("action scheme no attribute named portfolio.") - # Create a namespace for the environment + # If split is not enabled, create a single environment with NameSpace(name): - # Create streams from the dataset columns streams = [ Stream.source(dataset[c].tolist(), dtype=dataset[c].dtype).rename(c) for c in dataset.columns @@ -130,23 +153,21 @@ def create_env_from_dataframe( # Create a data feed from the streams feed = DataFeed(streams) - # Set the portfolio in the action scheme - action_scheme.portfolio = portfolio - - # Create an observer with the portfolio, feed, and other parameters + # Create the observer if not provided if observer is None: observer = observers.TradeFlowObserver( - portfolio=portfolio, + portfolio=kwargs.get("portfolio", None), feed=feed, renderer_feed=kwargs.get("renderer_feed", None), window_size=window_size, min_periods=min_periods, ) - # Create a stopper with the maximum allowed loss - stopper = stoppers.MaxLossStopper(max_allowed_loss=kwargs.get("max_allowed_loss", 0.5)) + # Create the stopper if not provided + if stopper is None: + stopper = stoppers.MaxLossStopper(max_allowed_loss=kwargs.get("max_allowed_loss", 0.5)) - # Create a renderer based on the provided renderer options + # Create the renderer based on the provided renderer options renderer_list = kwargs.get("renderer", renderers.EmptyRenderer()) if isinstance(renderer_list, list): @@ -175,31 +196,3 @@ def create_env_from_dataframe( ) return env - - -def train_test_split_env( - dataset: pd.DataFrame, test_size: float = 0.2, seed: int = 42 -) -> Tuple[TradingEnvironment, TradingEnvironment]: - """ - Splits the environment into training and testing environments by - splitting the underlying data it uses. - - This function assumes the TradingEnvironment has a method to reset - with different data subsets. - - Args: - dataset: The DataFrame to split. - test_size: The proportion of data to allocate to the testing environment (default: 0.2). - seed: Random seed for splitting the data (default: 42). - - Returns: - Tuple[TradingEnvironment, TradingEnvironment]: The training and testing environments. - """ - # Split the data using train_test_split - train_data, test_data = train_test_split(dataset, test_size=test_size, random_state=seed) - - # Create new environments with the split data (assuming a reset_with_data method) - train_env = create_env_from_dataframe(dataset=train_data) - test_env = create_env_from_dataframe(dataset=test_data) - - return train_env, test_env From 8fe03f665b8d5eecdf21c7190f98264a798ba1d9 Mon Sep 17 00:00:00 2001 From: Ojietohamen Samuel Date: Tue, 8 Oct 2024 16:52:25 +0100 Subject: [PATCH 03/12] feat: Add RL-based backtesting environment (need fix) --- .../environments/nt_backtest/__init__.py | 4 +- .../nt_backtest/{model.py => agent.py} | 115 +-- .../environments/nt_backtest/backtest.py | 141 ++- trade_flow/environments/nt_backtest/models.py | 33 + .../environments/nt_backtest/strategy.py | 63 +- .../nt_backtest/sub_training_env/__init__.py | 104 +++ .../nt_backtest/sub_training_env/actions.py | 433 +++++++++ .../nt_backtest/sub_training_env/informers.py | 10 + .../nt_backtest/sub_training_env/observers.py | 461 ++++++++++ .../nt_backtest/sub_training_env/renderers.py | 834 ++++++++++++++++++ .../nt_backtest/sub_training_env/rewards.py | 245 +++++ .../nt_backtest/sub_training_env/stoppers.py | 32 + trade_flow/environments/nt_backtest/utils.py | 15 +- 13 files changed, 2350 insertions(+), 140 deletions(-) rename trade_flow/environments/nt_backtest/{model.py => agent.py} (51%) create mode 100644 trade_flow/environments/nt_backtest/models.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/__init__.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/actions.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/informers.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/observers.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/renderers.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/rewards.py create mode 100644 trade_flow/environments/nt_backtest/sub_training_env/stoppers.py diff --git a/trade_flow/environments/nt_backtest/__init__.py b/trade_flow/environments/nt_backtest/__init__.py index 73fa0d0..1b5d579 100644 --- a/trade_flow/environments/nt_backtest/__init__.py +++ b/trade_flow/environments/nt_backtest/__init__.py @@ -1,6 +1,6 @@ """ -Provides an environment to train nautilus trader-based agents. +Provides an environment to train/test nautilus trader-based agents. """ from .node import * -from .model import * +from .agent import * diff --git a/trade_flow/environments/nt_backtest/model.py b/trade_flow/environments/nt_backtest/agent.py similarity index 51% rename from trade_flow/environments/nt_backtest/model.py rename to trade_flow/environments/nt_backtest/agent.py index 9ef24b6..1d30a26 100644 --- a/trade_flow/environments/nt_backtest/model.py +++ b/trade_flow/environments/nt_backtest/agent.py @@ -7,56 +7,24 @@ from nautilus_trader.core.data import Data from nautilus_trader.common.actor import Actor, ActorConfig -from nautilus_trader.model.data import DataType -from nautilus_trader.model.data import Bar, BarSpecification +from nautilus_trader.model.data import DataType, Bar, BarSpecification from nautilus_trader.model.identifiers import InstrumentId from nautilus_trader.common.enums import LogColor from nautilus_trader.core.data import Data from nautilus_trader.core.datetime import secs_to_nanos, unix_nanos_to_dt -from nautilus_trader.model.identifiers import TraderId # gym needed for the environment - -from sklearn.metrics import r2_score -from sklearn.linear_model import LinearRegression - +from trade_flow.agents.base import Agent +from trade_flow.agents.sb3_agent import SB3Agent +from trade_flow.environments.nt_backtest.models import Action, ModelUpdate from trade_flow.environments.nt_backtest.utils import bars_to_dataframe, make_bar_type +from trade_flow.environments.utils import create_env_from_dataframe + -# from nautilus_trader.persistence.catalog import ParquetDataCatalog as DataCatalog - - -class ModelUpdate(Data): - def __init__( - self, - model: LinearRegression, - hedge_ratio: float, - std_prediction: float, - ts_init: int, - ): - super().__init__(ts_init=ts_init, ts_event=ts_init) - self.model = model - self.hedge_ratio = hedge_ratio - self.std_prediction = std_prediction - - -class Prediction(Data): - def __init__( - self, - instrument_id: str, - prediction: float, - ts_init: int, - ): - super().__init__(ts_init=ts_init, ts_event=ts_init) - self.instrument_id = instrument_id - self.prediction = prediction - - -# -# The model Interface here would be the Agent Interface instead of LinearRegression -# class DRLAgentConfig(ActorConfig): - source_symbol: str - target_symbol: str + symbol: str + train_n_episodes: int = 2 + train_n_steps: int = 600 bar_spec: str = "10-SECOND-LAST" min_model_timedelta: str = "1D" @@ -65,10 +33,11 @@ class DRLAgent(Actor): def __init__(self, config: DRLAgentConfig): super().__init__(config=config) - self.source_id = InstrumentId.from_str(config.source_symbol) - self.target_id = InstrumentId.from_str(config.target_symbol) + self.instrument_id = InstrumentId.from_str(config.symbol) self.bar_spec = BarSpecification.from_str(self.config.bar_spec) - self.model: Optional[LinearRegression] = None + self.n_episodes: int = self.config.train_n_episodes + self.n_steps: int = self.config.train_n_steps + self.model: Optional[SB3Agent] = None # self.config.agent self.hedge_ratio: Optional[float] = None self._min_model_timedelta = secs_to_nanos( pd.Timedelta(self.config.min_model_timedelta).total_seconds() @@ -77,12 +46,10 @@ def __init__(self, config: DRLAgentConfig): def on_start(self): # Set instruments - self.left = self.cache.instrument(self.source_id) - self.right = self.cache.instrument(self.target_id) + self.instrument_cache = self.cache.instrument(self.instrument_id) # Subscribe to bars - self.subscribe_bars(make_bar_type(instrument_id=self.source_id, bar_spec=self.bar_spec)) - self.subscribe_bars(make_bar_type(instrument_id=self.target_id, bar_spec=self.bar_spec)) + self.subscribe_bars(make_bar_type(instrument_id=self.instrument_id, bar_spec=self.bar_spec)) def on_bar(self, bar: Bar): self._check_model_fit(bar) @@ -90,7 +57,7 @@ def on_bar(self, bar: Bar): @property def data_length_valid(self) -> bool: - return self._check_first_tick(self.source_id) and self._check_first_tick(self.target_id) + return self._check_first_tick(self.instrument_id) @property def has_fit_model_today(self): @@ -115,35 +82,37 @@ def _check_model_fit(self, bar: Bar): # Generate a dataframe from cached bar data df = bars_to_dataframe( - source_id=self.source_id.value, - source_bars=self.cache.bars( - bar_type=make_bar_type(self.source_id, bar_spec=self.bar_spec) - ), - target_id=self.target_id.value, - target_bars=self.cache.bars( - bar_type=make_bar_type(self.target_id, bar_spec=self.bar_spec) + instrument_id=self.instrument_id.value, + instrument_bars=self.cache.bars( + bar_type=make_bar_type(self.instrument_id, bar_spec=self.bar_spec) ), ) - # Format the arrays for scikit-learn - X = df.loc[:, self.source_id.value].astype(float).values.reshape(-1, 1) - Y = df.loc[:, self.target_id.value].astype(float).values.reshape(-1, 1) - # Fit a model - self.model = LinearRegression(fit_intercept=False) - self.model.fit(X, Y) + env = create_env_from_dataframe(dataset=df) + self.model = SB3Agent(env) + self.model.get_model("dqn", {"policy": "MlpPolicy"}) + self.model.train( + n_episodes=self.n_episodes, + n_steps=self.n_steps, + progress_bar=True, + ) + performance = pd.DataFrame.from_dict( + env.action_scheme.portfolio.performance, orient="index" + ) + print(performance) self.log.info( - f"Fit model @ {unix_nanos_to_dt(bar.ts_init)}, r2: {r2_score(Y, self.model.predict(X))}", + f"Fit model @ {unix_nanos_to_dt(bar.ts_init)}, performance: {performance}", color=LogColor.BLUE, ) self._last_model = unix_nanos_to_dt(bar.ts_init) # Record std dev of predictions (used for scaling our order price) - pred = self.model.predict(X) + pred = self.model.predict(state) # TODO: a utility function to create state errors = pred - Y std_pred = errors.std() - # The model slope is our hedge ratio (the ratio of source + # The model slope is our hedge ratio (the ratio of self.hedge_ratio = float(self.model.coef_[0][0]) self.log.info(f"Computed hedge_ratio={self.hedge_ratio:0.4f}", color=LogColor.BLUE) @@ -155,17 +124,17 @@ def _check_model_fit(self, bar: Bar): ts_init=bar.ts_init, ) self.publish_data( - data_type=DataType(ModelUpdate, metadata={"instrument_id": self.target_id.value}), + data_type=DataType(ModelUpdate, metadata={"instrument_id": self.instrument_id.value}), data=model_update, ) def _predict(self, bar: Bar): - if self.model is not None and bar.bar_type.instrument_id == self.source_id: - pred = self.model.predict([[bar.close]])[0][0] - prediction = Prediction( - instrument_id=self.target_id, prediction=pred, ts_init=bar.ts_init - ) + if self.model is not None and bar.bar_type.instrument_id == self.instrument_id: + action, _ = self.model.predict( + [[bar.close]] + ) # TODO: a utility function to create state + action = Action(instrument_id=self.instrument_id, action=action, ts_init=bar.ts_init) self.publish_data( - data_type=DataType(Prediction, metadata={"instrument_id": self.target_id.value}), - data=prediction, + data_type=DataType(Action, metadata={"instrument_id": self.instrument_id.value}), + data=action, ) diff --git a/trade_flow/environments/nt_backtest/backtest.py b/trade_flow/environments/nt_backtest/backtest.py index 0e75237..3bac8ef 100644 --- a/trade_flow/environments/nt_backtest/backtest.py +++ b/trade_flow/environments/nt_backtest/backtest.py @@ -1,5 +1,9 @@ import pathlib -from typing import Tuple +import shutil +import pandas as pd +import yfinance as yf +from pathlib import Path +from typing import Optional, Tuple from nautilus_trader.backtest.node import BacktestNode from nautilus_trader.config import ( @@ -13,16 +17,48 @@ RiskEngineConfig, StreamingConfig, ) -from nautilus_trader.model.data import Bar +from nautilus_trader.model.data import Bar, BarType, BarSpecification from nautilus_trader.model.identifiers import InstrumentId -from nautilus_trader.persistence.catalog import ParquetDataCatalog # as DataCatalog +from nautilus_trader.test_kit.providers import CSVTickDataLoader, TestInstrumentProvider +from nautilus_trader.persistence.wranglers import QuoteTickDataWrangler, BarDataWrangler +from nautilus_trader.persistence.catalog import ParquetDataCatalog + +# from trade_flow.agents.sb3_agent import SB3Agent # CATALOG = DataCatalog(str(pathlib.Path(__file__).parent.joinpath("catalog"))) +CATALOG_PATH = Path.cwd() / "catalog" + + +def fetch_data( + instrument_name: str, + datetime_format: str = "mixed", + interval: str = "1d", + start: str = "2020-01-01", + end: str = "2021-01-01,", + venue: str = "NASDAQ", +): + data: Optional[pd.DataFrame] = yf.download( + instrument_name, interval=interval, start=start, end=end + ) + data.index = pd.to_datetime(data.index, format=datetime_format) + data.drop(columns=["Adj Close"], axis=1, inplace=True) + data.index.set_names("timestamp", inplace=True) + + # Process bars using a wrangler + INSTRUMENT = TestInstrumentProvider.equity(symbol=instrument_name.upper(), venue=venue.upper()) + wrangler = BarDataWrangler( + bar_type=BarType.from_str(f"{instrument_name.upper()}.{venue}-1-DAY-LAST-EXTERNAL"), + instrument=INSTRUMENT, + ) + + bars = wrangler.process(data) + + return INSTRUMENT, bars def main( - instrument_ids: Tuple[str, str], + instrument_id: str, catalog: ParquetDataCatalog, notional_trade_size_usd: int = 10_000, start_time: str = None, @@ -32,23 +68,21 @@ def main( persistence: bool = False, **strategy_kwargs, ): - # Create model prediction actor - prediction = ImportableActorConfig( - actor_path="model:DRLAgentActor", - config_path="model:DRLAgentConfig", + # Create agent model actor + agent = ImportableActorConfig( + actor_path="agent:DRLAgent", + config_path="agent:DRLAgentConfig", config=dict( - source_symbol=instrument_ids[0], - target_symbol=instrument_ids[1], + symbol=instrument_id, ), ) # Create strategy strategy = ImportableStrategyConfig( - strategy_path="strategy:PairTrader", - config_path="strategy:PairTraderConfig", + strategy_path="strategy:DRLAgentStrategy", + config_path="strategy:DRLAgentStrategyConfig", config=dict( - source_symbol=instrument_ids[0], - target_symbol=instrument_ids[1], + symbol=instrument_id, notional_trade_size_usd=notional_trade_size_usd, **strategy_kwargs, ), @@ -63,18 +97,18 @@ def main( streaming=StreamingConfig(catalog_path=str(catalog.path)) if persistence else None, risk_engine=RiskEngineConfig(max_order_submit_rate="1000/00:00:01"), # type: ignore strategies=[strategy], - actors=[prediction], + actors=[agent], ) venues = [ BacktestVenueConfig( - name="NASDAQ", + name="BINANCE", oms_type="NETTING", account_type="CASH", base_currency="USD", starting_balances=["1_000_000 USD"], ) ] - print("instrument_ids => ", instrument_ids) + data = [ BacktestDataConfig( data_cls=Bar.fully_qualified_name(), @@ -85,7 +119,7 @@ def main( start_time=start_time, end_time=end_time, ) - for instrument_id in instrument_ids + # for instrument_id in instrument_ids ] run_config = BacktestRunConfig(engine=engine, venues=venues, data=data) @@ -95,18 +129,65 @@ def main( return node.run() -# # typer.run(main) -# # lr_catalog = LR_MODEL_DATA_CATALOG +if __name__ == "__main__": + smh_instrument, smh_bars = fetch_data( + "SMH", + interval="1d", + datetime_format="%Y-%m-%d %H:%M:%S.%f", + start="2020-01-01", + end="2021-01-01", + ) + print(smh_instrument, smh_bars) + + soxx_instrument, soxx_bars = fetch_data( + "SOXX", + interval="1d", + datetime_format="%Y-%m-%d %H:%M:%S.%f", + start="2020-01-01", + end="2021-01-01", + ) + print(soxx_instrument, soxx_bars) + + # Clear if it already exists, then create fresh + if CATALOG_PATH.exists(): + shutil.rmtree(CATALOG_PATH) + CATALOG_PATH.mkdir(parents=True) + + # Create a catalog instance + nautilus_talks_catalog = ParquetDataCatalog(CATALOG_PATH) + + # Write instrument to the catalog + nautilus_talks_catalog.write_data([smh_instrument, soxx_instrument]) + + # Write bars to catalog + nautilus_talks_catalog.write_data(smh_bars) + nautilus_talks_catalog.write_data(soxx_bars) -# assert len(nautilus_talks_catalog.instruments())>0, "Couldn't load instruments, have you run `poetry run inv extract-catalog`?" + nautilus_talks_catalog_path = str(pathlib.Path.cwd().joinpath("catalog")) + print(nautilus_talks_catalog_path) -# [result] = main( -# catalog=nautilus_talks_catalog, -# instrument_ids=("SMH.NASDAQ", "SOXX.NASDAQ"), -# # instrument_ids=("EURUSD.SIM"), -# log_level="INFO", -# persistence=False, -# end_time="2020-06-01", -# ) + nautilus_talks_catalog = ParquetDataCatalog(nautilus_talks_catalog_path) + + print(nautilus_talks_catalog.instruments()) + + # typer.run(main) + # lr_catalog = LR_MODEL_DATA_CATALOG + + assert ( + len(nautilus_talks_catalog.instruments()) > 0 + ), "Couldn't load instruments, have you run `poetry run inv extract-catalog`?" + + results = main( + catalog=nautilus_talks_catalog, + instrument_id="BTCUSD.BINANCE", + # instrument_ids=("SMH.NASDAQ", "SOXX.NASDAQ"), + # instrument_ids=("EURUSD.SIM"), + log_level="INFO", + persistence=False, + end_time="2020-06-01", + ) -# print(result.instance_id) + print("\n\n") + print(results) + if len(results) > 0: + print(results[0].instance_id) diff --git a/trade_flow/environments/nt_backtest/models.py b/trade_flow/environments/nt_backtest/models.py new file mode 100644 index 0000000..1398ff1 --- /dev/null +++ b/trade_flow/environments/nt_backtest/models.py @@ -0,0 +1,33 @@ +from typing import Tuple +from nautilus_trader.core.data import Data +from trade_flow.agents.sb3_agent import SB3Agent + + +class ModelUpdate(Data): + def __init__( + self, + model: SB3Agent, + hedge_ratio: float, + std_action: Tuple, + ts_init: int, + ): + super().__init__(ts_init=ts_init, ts_event=ts_init) + self.model = model + self.hedge_ratio = hedge_ratio + self.std_action = std_action + + +class Action(Data): + def __init__( + self, + instrument_id: str, + action: Tuple, + ts_init: int, + ): + super().__init__(ts_init=ts_init, ts_event=ts_init) + self.instrument_id = instrument_id + self.action = action + + +class RepeatedEventComplete(Exception): + pass diff --git a/trade_flow/environments/nt_backtest/strategy.py b/trade_flow/environments/nt_backtest/strategy.py index 66303b5..e26efbe 100644 --- a/trade_flow/environments/nt_backtest/strategy.py +++ b/trade_flow/environments/nt_backtest/strategy.py @@ -3,14 +3,12 @@ from typing import Optional import pandas as pd -from trade_flow.environments.nt_backtest.model import ModelUpdate, Prediction from nautilus_trader.common.enums import LogColor from nautilus_trader.config import StrategyConfig from nautilus_trader.core.data import Data -from nautilus_trader.core.datetime import unix_nanos_to_dt from nautilus_trader.core.message import Event -from nautilus_trader.model.data.bar import Bar, BarSpecification -from nautilus_trader.model.data.base import DataType +from nautilus_trader.core.datetime import unix_nanos_to_dt +from nautilus_trader.model.data import Bar, BarSpecification, DataType from nautilus_trader.model.enums import OrderSide, PositionSide, TimeInForce from nautilus_trader.model.events.position import ( PositionChanged, @@ -23,16 +21,13 @@ from nautilus_trader.model.position import Position from nautilus_trader.trading.strategy import Strategy from nautilus_trader.model.functions import order_side_to_str +from trade_flow.environments.nt_backtest.agent import ModelUpdate, Action +from trade_flow.environments.nt_backtest.models import RepeatedEventComplete from trade_flow.environments.nt_backtest.utils import human_readable_duration, make_bar_type -class RepeatedEventComplete(Exception): - pass - - -class PairTraderConfig(StrategyConfig): - source_symbol: str - target_symbol: str +class DRLAgentStrategyConfig(StrategyConfig): + symbol: str notional_trade_size_usd: int = 10_000 min_model_timedelta: datetime.timedelta = datetime.timedelta(days=1) trade_width_std_dev: float = 2.5 @@ -40,11 +35,10 @@ class PairTraderConfig(StrategyConfig): ib_long_short_margin_requirement = (0.25 + 0.17) / 2.0 -class PairTrader(Strategy): - def __init__(self, config: PairTraderConfig): +class DRLAgentStrategy(Strategy): + def __init__(self, config: DRLAgentStrategyConfig): super().__init__(config=config) - self.source_id = InstrumentId.from_str(config.source_symbol) - self.target_id = InstrumentId.from_str(config.target_symbol) + self.instrument_id = InstrumentId.from_str(config.symbol) self.model: Optional[ModelUpdate] = None self.hedge_ratio: Optional[float] = None self.std_pred: Optional[float] = None @@ -57,19 +51,17 @@ def __init__(self, config: PairTraderConfig): def on_start(self): # Set instruments - self.source = self.cache.instrument(self.source_id) - self.target = self.cache.instrument(self.target_id) + self.instrument = self.cache.instrument(self.instrument_id) # Subscribe to bars - self.subscribe_bars(make_bar_type(instrument_id=self.source_id, bar_spec=self.bar_spec)) - self.subscribe_bars(make_bar_type(instrument_id=self.target_id, bar_spec=self.bar_spec)) + self.subscribe_bars(make_bar_type(instrument_id=self.instrument_id, bar_spec=self.bar_spec)) # Subscribe to model and predictions self.subscribe_data( - data_type=DataType(ModelUpdate, metadata={"instrument_id": self.target_id.value}) + data_type=DataType(ModelUpdate, metadata={"instrument_id": self.instrument_id.value}) ) self.subscribe_data( - data_type=DataType(Prediction, metadata={"instrument_id": self.target_id.value}) + data_type=DataType(Action, metadata={"instrument_id": self.instrument_id.value}) ) def on_bar(self, bar: Bar): @@ -80,7 +72,7 @@ def on_bar(self, bar: Bar): def on_data(self, data: Data): if isinstance(data, ModelUpdate): self._on_model_update(data) - elif isinstance(data, Prediction): + elif isinstance(data, Action): self._on_prediction(data) else: raise TypeError() @@ -90,14 +82,14 @@ def on_event(self, event: Event): if isinstance(event, (PositionOpened, PositionChanged)): position = self.cache.position(event.position_id) self._log.info(f"{position}", color=LogColor.YELLOW) - assert position.quantity < 200 # Runtime check for bug in code + assert position.quantity < 200 # TODO: Runtime check for bug in code def _on_model_update(self, model_update: ModelUpdate): self.model = model_update.model self.hedge_ratio = model_update.hedge_ratio self.std_pred = model_update.std_prediction - def _on_prediction(self, prediction: Prediction): + def _on_prediction(self, prediction: Action): self.prediction = prediction.prediction self._update_theoretical() @@ -204,13 +196,13 @@ def _check_for_hedge(self, timer=None, event: Optional[Event] = None): return def _hedge_position(self, event: PositionEvent): - # We've opened or changed position in our source instrument, we will likely need to hedge. + # We've opened or changed position in our instrument, we will likely need to hedge. target_position = self.cache.position(event.position_id) hedge_quantity = int(round(target_position.quantity * self.hedge_ratio, 0)) quantity = 0 if isinstance(event, PositionClosed): # (possibly) Reducing our position in the target instrument - source_position: Position = self.current_position(self.source_id) + source_position: Position = self.current_position(self.instrument_id) if source_position is not None and source_position.is_closed: if source_position.id.value not in self._summarised: self._summarise_position() @@ -220,22 +212,26 @@ def _hedge_position(self, event: PositionEvent): else: # (possibly) Increasing our position in hedge instrument side = self._opposite_side(target_position.side) - quantity = self._cap_volume(instrument_id=self.source_id, max_quantity=hedge_quantity) + quantity = self._cap_volume( + instrument_id=self.instrument_id, max_quantity=hedge_quantity + ) if quantity == 0: # Fully hedged, cancel any existing orders - for order in self.cache.orders_open(instrument_id=self.source_id, strategy_id=self.id): + for order in self.cache.orders_open( + instrument_id=self.instrument_id, strategy_id=self.id + ): self.cancel_order(order=order) raise RepeatedEventComplete - elif self.cache.orders_inflight(instrument_id=self.source_id, strategy_id=self.id): + elif self.cache.orders_inflight(instrument_id=self.instrument_id, strategy_id=self.id): # Don't send more orders if we have some currently in-flight return # Cancel any existing orders - for order in self.cache.orders_open(instrument_id=self.source_id, strategy_id=self.id): + for order in self.cache.orders_open(instrument_id=self.instrument_id, strategy_id=self.id): self.cancel_order(order=order) order = self.order_factory.market( - instrument_id=self.source_id, + instrument_id=self.instrument_id, order_side=side, quantity=Quantity.from_int(quantity), ) @@ -289,7 +285,7 @@ def _exit_position(self, bar: Bar): def current_position(self, instrument_id: InstrumentId) -> Optional[Position]: try: - side = {self.source_id: "source", self.target_id: "target"}[instrument_id] + side = {self.instrument_id: "target"}[instrument_id] return self.cache.position(PositionId(f"{side}-{self._position_id}")) except AssertionError: return None @@ -345,5 +341,4 @@ def peak_notional(pos): self._summarised.add(src_pos.id.value) def on_stop(self): - self.close_all_positions(self.source_id) - self.close_all_positions(self.target_id) + self.close_all_positions(self.instrument_id) diff --git a/trade_flow/environments/nt_backtest/sub_training_env/__init__.py b/trade_flow/environments/nt_backtest/sub_training_env/__init__.py new file mode 100644 index 0000000..4e453f4 --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/__init__.py @@ -0,0 +1,104 @@ +""" +Create an OMS for it OR find a way to use the already underlying nautilus trader system +- The first approach seems tedious but simpler + +Provides a default environment, can also be used as specification for writing custom environments. + +Dependencies: + - Generic Environment +""" + +from typing import Union + +from trade_flow.feed import DataFeed +from trade_flow.environments.generic import TradingEnvironment +from trade_flow.environments.generic.components.renderer import AggregateRenderer +from trade_flow.environments.default.oms.portfolio import Portfolio + +from . import actions +from . import rewards +from . import observers +from . import stoppers +from . import informers +from . import renderers + + +def create( + portfolio: Portfolio, + action_scheme: "Union[actions.TradeFlowActionScheme, str]", + reward_scheme: "Union[rewards.TradeFlowRewardScheme, str]", + feed: "DataFeed", + window_size: int = 1, + min_periods: int = None, + random_start_pct: float = 0.00, + **kwargs, +) -> TradingEnvironment: + """Creates the default `TradingEnvironment` of the project to be used in training + RL agents. + + Parameters + ---------- + portfolio : `Portfolio` + The portfolio to be used by the environment. + action_scheme : `actions.TradeFlowActionScheme` or str + The action scheme for computing actions at every step of an episode. + reward_scheme : `rewards.TradeFlowRewardScheme` or str + The reward scheme for computing rewards at every step of an episode. + feed : `DataFeed` + The feed for generating observations to be used in the look back + window. + window_size : int + The size of the look back window to use for the observation space. + min_periods : int, optional + The minimum number of steps to warm up the `feed`. + random_start_pct : float, optional + Whether to randomize the starting point within the environment at each + observer reset, starting in the first X percentage of the sample + **kwargs : keyword arguments + Extra keyword arguments needed to build the environment. + + Returns + ------- + `TradingEnvironment` + The default trading environment. + """ + + action_scheme = actions.get(action_scheme) if isinstance(action_scheme, str) else action_scheme + reward_scheme = rewards.get(reward_scheme) if isinstance(reward_scheme, str) else reward_scheme + + action_scheme.portfolio = portfolio + + observer = observers.TradeFlowObserver( + portfolio=portfolio, + feed=feed, + renderer_feed=kwargs.get("renderer_feed", None), + window_size=window_size, + min_periods=min_periods, + ) + + stopper = stoppers.MaxLossStopper(max_allowed_loss=kwargs.get("max_allowed_loss", 0.5)) + + renderer_list = kwargs.get("renderer", renderers.EmptyRenderer()) + + if isinstance(renderer_list, list): + for i, r in enumerate(renderer_list): + if isinstance(r, str): + renderer_list[i] = renderers.get(r) + renderer = AggregateRenderer(renderer_list) + else: + if isinstance(renderer_list, str): + renderer = renderers.get(renderer_list) + else: + renderer = renderer_list + + env = TradingEnvironment( + action_scheme=action_scheme, + reward_scheme=reward_scheme, + observer=observer, + stopper=kwargs.get("stopper", stopper), + informer=kwargs.get("informer", informers.TradeFlowInformer()), + renderer=renderer, + min_periods=min_periods, + random_start_pct=random_start_pct, + ) + return env diff --git a/trade_flow/environments/nt_backtest/sub_training_env/actions.py b/trade_flow/environments/nt_backtest/sub_training_env/actions.py new file mode 100644 index 0000000..5c18520 --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/actions.py @@ -0,0 +1,433 @@ +import logging +from abc import abstractmethod +from itertools import product +from typing import Union, List, Any + +from gymnasium.spaces import Space, Discrete + +from trade_flow.core import Clock +from trade_flow.environments.generic import ActionScheme, TradingEnvironment + +from trade_flow.environments.default.oms.orders import ( + Broker, + Order, + OrderListener, + OrderSpec, + proportion_order, + risk_managed_order, + TradeSide, + TradeType, +) +from trade_flow.environments.default.oms.portfolio import Portfolio + + +class TradeFlowActionScheme(ActionScheme): + """An abstract base class for any `ActionScheme` that wants to be + compatible with the built in OMS. + + The structure of the action scheme is built to make sure that action space + can be used with the system, provided that the user defines the methods to + interpret that action. + + Attributes + ---------- + portfolio : 'Portfolio' + The portfolio object to be used in defining actions. + broker : 'Broker' + The broker object to be used for placing orders in the OMS. + + Methods + ------- + perform(env,portfolio) + Performs the action on the given environment. + get_orders(action,portfolio) + Gets the list of orders to be submitted for the given action. + """ + + def __init__(self) -> None: + super().__init__() + self.portfolio: "Portfolio" = None + self.broker: "Broker" = Broker() + + @property + def clock(self) -> "Clock": + """The reference clock from the environment. (`Clock`) + + When the clock is set for the we also set the clock for the portfolio + as well as the exchanges defined in the portfolio. + + Returns + ------- + `Clock` + The environment clock. + """ + return self._clock + + @clock.setter + def clock(self, clock: "Clock") -> None: + self._clock = clock + + components = [self.portfolio] + self.portfolio.exchanges + for c in components: + c.clock = clock + self.broker.clock = clock + + def perform(self, env: "TradingEnvironment", action: Any) -> None: + """Performs the action on the given environment. + + Under the TT action scheme, the subclassed action scheme is expected + to provide a method for getting a list of orders to be submitted to + the broker for execution in the OMS. + + Parameters + ---------- + env : 'TradingEnvironment' + The environment to perform the action on. + action : Any + The specific action selected from the action space. + """ + orders = self.get_orders(action, self.portfolio) + + for order in orders: + if order: + logging.info("Step {}: {} {}".format(order.step, order.side, order.quantity)) + self.broker.submit(order) + + self.broker.update() + + @abstractmethod + def get_orders(self, action: Any, portfolio: "Portfolio") -> "List[Order]": + """Gets the list of orders to be submitted for the given action. + + Parameters + ---------- + action : Any + The action to be interpreted. + portfolio : 'Portfolio' + The portfolio defined for the environment. + + Returns + ------- + List[Order] + A list of orders to be submitted to the broker. + """ + raise NotImplementedError() + + def reset(self) -> None: + """Resets the action scheme.""" + self.portfolio.reset() + self.broker.reset() + + +class BSH(TradeFlowActionScheme): + """A simple discrete action scheme where the only options are to buy, sell, + or hold. + + Parameters + ---------- + cash : `Wallet` + The wallet to hold funds in the base intrument. + asset : `Wallet` + The wallet to hold funds in the quote instrument. + """ + + registered_name = "bsh" + + def __init__(self, cash: "Wallet", asset: "Wallet"): + super().__init__() + self.cash = cash + self.asset = asset + + self.listeners = [] + self.action = 0 + + @property + def action_space(self): + return Discrete(2) + + def attach(self, listener): + self.listeners += [listener] + return self + + def get_orders(self, action: int, portfolio: "Portfolio") -> "Order": + order = None + + if abs(action - self.action) > 0: + src = self.cash if self.action == 0 else self.asset + tgt = self.asset if self.action == 0 else self.cash + + if ( + src.balance == 0 + ): # We need to check, regardless of the proposed order, if we have balance in 'src' + return [] # Otherwise just return an empty order list + + order = proportion_order(portfolio, src, tgt, 1.0) + self.action = action + + for listener in self.listeners: + listener.on_action(action) + + return [order] + + def reset(self): + super().reset() + self.action = 0 + + +class SimpleOrders(TradeFlowActionScheme): + """A discrete action scheme that determines actions based on a list of + trading pairs, order criteria, and trade sizes. + + Parameters + ---------- + criteria : List[OrderCriteria] + A list of order criteria to select from when submitting an order. + (e.g. MarketOrder, LimitOrder w/ price, StopLoss, etc.) + trade_sizes : List[float] + A list of trade sizes to select from when submitting an order. + (e.g. '[1, 1/3]' = 100% or 33% of balance is tradable. + '4' = 25%, 50%, 75%, or 100% of balance is tradable.) + durations : List[int] + A list of durations to select from when submitting an order. + trade_type : TradeType + A type of trade to make. + order_listener : OrderListener + A callback class to use for listening to steps of the order process. + min_order_pct : float + The minimum value when placing an order, calculated in percent over net_worth. + min_order_abs : float + The minimum value when placing an order, calculated in absolute order value. + """ + + def __init__( + self, + criteria: "Union[List[OrderCriteria], OrderCriteria]" = None, + trade_sizes: "Union[List[float], int]" = 10, + durations: "Union[List[int], int]" = None, + trade_type: "TradeType" = TradeType.MARKET, + order_listener: "OrderListener" = None, + min_order_pct: float = 0.02, + min_order_abs: float = 0.00, + ) -> None: + super().__init__() + self.min_order_pct = min_order_pct + self.min_order_abs = min_order_abs + criteria = self.default("criteria", criteria) + self.criteria = criteria if isinstance(criteria, list) else [criteria] + + trade_sizes = self.default("trade_sizes", trade_sizes) + if isinstance(trade_sizes, list): + self.trade_sizes = trade_sizes + else: + self.trade_sizes = [(x + 1) / trade_sizes for x in range(trade_sizes)] + + durations = self.default("durations", durations) + self.durations = durations if isinstance(durations, list) else [durations] + + self._trade_type = self.default("trade_type", trade_type) + self._order_listener = self.default("order_listener", order_listener) + + self._action_space = None + self.actions = None + + @property + def action_space(self) -> Space: + if not self._action_space: + self.actions = product( + self.criteria, self.trade_sizes, self.durations, [TradeSide.BUY, TradeSide.SELL] + ) + self.actions = list(self.actions) + self.actions = list(product(self.portfolio.exchange_pairs, self.actions)) + self.actions = [None] + self.actions + + self._action_space = Discrete(len(self.actions)) + return self._action_space + + def get_orders(self, action: int, portfolio: "Portfolio") -> "List[Order]": + + if action == 0: + return [] + + (ep, (criteria, proportion, duration, side)) = self.actions[action] + + instrument = side.instrument(ep.pair) + wallet = portfolio.get_wallet(ep.exchange.id, instrument=instrument) + + balance = wallet.balance.as_float() + size = balance * proportion + size = min(balance, size) + + quantity = (size * instrument).quantize() + + if ( + size < 10**-instrument.precision + or size < self.min_order_pct * portfolio.net_worth + or size < self.min_order_abs + ): + return [] + + order = Order( + step=self.clock.step, + side=side, + trade_type=self._trade_type, + exchange_pair=ep, + price=ep.price, + quantity=quantity, + criteria=criteria, + end=self.clock.step + duration if duration else None, + portfolio=portfolio, + ) + + if self._order_listener is not None: + order.attach(self._order_listener) + + return [order] + + +class ManagedRiskOrders(TradeFlowActionScheme): + """A discrete action scheme that determines actions based on managing risk, + through setting a follow-up stop loss and take profit on every order. + + Parameters + ---------- + stop : List[float] + A list of possible stop loss percentages for each order. + take : List[float] + A list of possible take profit percentages for each order. + trade_sizes : List[float] + A list of trade sizes to select from when submitting an order. + (e.g. '[1, 1/3]' = 100% or 33% of balance is tradable. + '4' = 25%, 50%, 75%, or 100% of balance is tradable.) + durations : List[int] + A list of durations to select from when submitting an order. + trade_type : `TradeType` + A type of trade to make. + order_listener : OrderListener + A callback class to use for listening to steps of the order process. + min_order_pct : float + The minimum value when placing an order, calculated in percent over net_worth. + min_order_abs : float + The minimum value when placing an order, calculated in absolute order value. + """ + + def __init__( + self, + stop: "List[float]" = [0.02, 0.04, 0.06], + take: "List[float]" = [0.01, 0.02, 0.03], + trade_sizes: "Union[List[float], int]" = 10, + durations: "Union[List[int], int]" = None, + trade_type: "TradeType" = TradeType.MARKET, + order_listener: "OrderListener" = None, + min_order_pct: float = 0.02, + min_order_abs: float = 0.00, + ) -> None: + super().__init__() + self.min_order_pct = min_order_pct + self.min_order_abs = min_order_abs + self.stop = self.default("stop", stop) + self.take = self.default("take", take) + + trade_sizes = self.default("trade_sizes", trade_sizes) + if isinstance(trade_sizes, list): + self.trade_sizes = trade_sizes + else: + self.trade_sizes = [(x + 1) / trade_sizes for x in range(trade_sizes)] + + durations = self.default("durations", durations) + self.durations = durations if isinstance(durations, list) else [durations] + + self._trade_type = self.default("trade_type", trade_type) + self._order_listener = self.default("order_listener", order_listener) + + self._action_space = None + self.actions = None + + @property + def action_space(self) -> "Space": + if not self._action_space: + self.actions = product( + self.stop, + self.take, + self.trade_sizes, + self.durations, + [TradeSide.BUY, TradeSide.SELL], + ) + self.actions = list(self.actions) + self.actions = list(product(self.portfolio.exchange_pairs, self.actions)) + self.actions = [None] + self.actions + + self._action_space = Discrete(len(self.actions)) + return self._action_space + + def get_orders(self, action: int, portfolio: "Portfolio") -> "List[Order]": + + if action == 0: + return [] + + (ep, (stop, take, proportion, duration, side)) = self.actions[action] + + side = TradeSide(side) + + instrument = side.instrument(ep.pair) + wallet = portfolio.get_wallet(ep.exchange.id, instrument=instrument) + + balance = wallet.balance.as_float() + size = balance * proportion + size = min(balance, size) + quantity = (size * instrument).quantize() + + if ( + size < 10**-instrument.precision + or size < self.min_order_pct * portfolio.net_worth + or size < self.min_order_abs + ): + return [] + + params = { + "side": side, + "exchange_pair": ep, + "price": ep.price, + "quantity": quantity, + "down_percent": stop, + "up_percent": take, + "portfolio": portfolio, + "trade_type": self._trade_type, + "end": self.clock.step + duration if duration else None, + } + + order = risk_managed_order(**params) + + if self._order_listener is not None: + order.attach(self._order_listener) + + return [order] + + +_registry = { + "bsh": BSH, + "simple": SimpleOrders, + "managed-risk": ManagedRiskOrders, +} + + +def get(identifier: str) -> "ActionScheme": + """Gets the `ActionScheme` that matches with the identifier. + + Parameters + ---------- + identifier : str + The identifier for the `ActionScheme`. + + Returns + ------- + 'ActionScheme' + The action scheme associated with the `identifier`. + + Raises + ------ + KeyError: + Raised if the `identifier` is not associated with any `ActionScheme`. + """ + if identifier not in _registry.keys(): + raise KeyError(f"Identifier {identifier} is not associated with any `ActionScheme`.") + return _registry[identifier]() diff --git a/trade_flow/environments/nt_backtest/sub_training_env/informers.py b/trade_flow/environments/nt_backtest/sub_training_env/informers.py new file mode 100644 index 0000000..d8e2105 --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/informers.py @@ -0,0 +1,10 @@ +from trade_flow.environments.generic import Informer, TradingEnvironment + + +class TradeFlowInformer(Informer): + + def __init__(self) -> None: + super().__init__() + + def info(self, env: "TradingEnvironment") -> dict: + return {"step": self.clock.step, "net_worth": env.action_scheme.portfolio.net_worth} diff --git a/trade_flow/environments/nt_backtest/sub_training_env/observers.py b/trade_flow/environments/nt_backtest/sub_training_env/observers.py new file mode 100644 index 0000000..f2877cf --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/observers.py @@ -0,0 +1,461 @@ +from typing import List + + +import datetime as dt +import numpy as np +import pandas as pd + +from gymnasium.spaces import Box, Space +from random import randrange + + +from trade_flow.feed import Stream, NameSpace, DataFeed +from trade_flow.environments.default.oms.wallet import Wallet +from trade_flow.environments.default.oms.portfolio import Portfolio +from trade_flow.environments.generic import Observer +from collections import OrderedDict + + +def _create_wallet_source(wallet: "Wallet", include_worth: bool = True) -> "List[Stream[float]]": + """Creates a list of streams to describe a `Wallet`. + + Parameters + ---------- + wallet : `Wallet` + The wallet to make streams for. + include_worth : bool, default True + Whether or + + Returns + ------- + `List[Stream[float]]` + A list of streams to describe the `wallet`. + """ + venue_name = wallet.exchange.name + symbol = wallet.instrument.symbol + + streams = [] + + with NameSpace(venue_name + ":/" + symbol): + free_balance = Stream.sensor(wallet, lambda w: w.balance.as_float(), dtype="float").rename( + "free" + ) + locked_balance = Stream.sensor( + wallet, lambda w: w.locked_balance.as_float(), dtype="float" + ).rename("locked") + total_balance = Stream.sensor( + wallet, lambda w: w.total_balance.as_float(), dtype="float" + ).rename("total") + + streams += [free_balance, locked_balance, total_balance] + + if include_worth: + price = Stream.select( + wallet.exchange.streams(), lambda node: node.name.endswith(symbol) + ) + worth = price.mul(total_balance).rename("worth") + streams += [worth] + + return streams + + +def _create_internal_streams(portfolio: "Portfolio") -> "List[Stream[float]]": + """Creates a list of streams to describe a `Portfolio`. + + Parameters + ---------- + portfolio : `Portfolio` + The portfolio to make the streams for. + + Returns + ------- + `List[Stream[float]]` + A list of streams to describe the `portfolio`. + """ + base_symbol = portfolio.base_instrument.symbol + sources = [] + + for wallet in portfolio.wallets: + symbol = wallet.instrument.symbol + sources += wallet.exchange.streams() + sources += _create_wallet_source(wallet, include_worth=(symbol != base_symbol)) + + worth_streams = [] + for s in sources: + if s.name.endswith(base_symbol + ":/total") or s.name.endswith("worth"): + worth_streams += [s] + + net_worth = Stream.reduce(worth_streams).sum().rename("net_worth") + sources += [net_worth] + + return sources + + +class ObservationHistory(object): + """Stores observations from a given episode of the environment. + + Parameters + ---------- + window_size : int + The amount of observations to keep stored before discarding them. + + Attributes + ---------- + window_size : int + The amount of observations to keep stored before discarding them. + rows : pd.DataFrame + The rows of observations that are used as the environment observation + at each step of an episode. + + """ + + def __init__(self, window_size: int) -> None: + self.window_size = window_size + self.rows = OrderedDict() + self.index = 0 + + def push(self, row: dict) -> None: + """Stores an observation. + + Parameters + ---------- + row : dict + The new observation to store. + """ + self.rows[self.index] = row + self.index += 1 + if len(self.rows.keys()) > self.window_size: + del self.rows[list(self.rows.keys())[0]] + + def observe(self) -> "np.array": + """Gets the observation at a given step in an episode + + Returns + ------- + `np.array` + The current observation of the environment. + """ + rows = self.rows.copy() + + if len(rows) < self.window_size: + size = self.window_size - len(rows) + padding = np.zeros((size, len(rows[list(rows.keys())[0]]))) + r = np.array([list(inner_dict.values()) for inner_dict in rows.values()]) + rows = np.concatenate((padding, r)) + + if isinstance(rows, OrderedDict): + rows = np.array([list(inner_dict.values()) for inner_dict in rows.values()]) + + rows = np.nan_to_num(rows) + + return rows + + def reset(self) -> None: + """Resets the observation history""" + self.rows = OrderedDict() + self.index = 0 + + +class TradeFlowObserver(Observer): + """The TradeFlow observer that is compatible with the other `default` + components. + + Parameters + ---------- + portfolio : `Portfolio` + The portfolio to be used to create the internal data feed mechanism. + feed : `DataFeed` + The feed to be used to collect observations to the observation window. + renderer_feed : `DataFeed` + The feed to be used for giving information to the renderer. + window_size : int + The size of the observation window. + min_periods : int + The amount of steps needed to warmup the `feed`. + **kwargs : keyword arguments + Additional keyword arguments for observer creation. + + Attributes + ---------- + feed : `DataFeed` + The master feed in charge of streaming the internal, external, and + renderer data feeds. + window_size : int + The size of the observation window. + min_periods : int + The amount of steps needed to warmup the `feed`. + history : `ObservationHistory` + The observation history. + renderer_history : `List[dict]` + The history of the renderer data feed. + """ + + def __init__( + self, + portfolio: "Portfolio", + feed: "DataFeed" = None, + renderer_feed: "DataFeed" = None, + window_size: int = 1, + min_periods: int = None, + **kwargs, + ) -> None: + internal_group = Stream.group(_create_internal_streams(portfolio)).rename("internal") + external_group = Stream.group(feed.inputs).rename("external") + + if renderer_feed: + renderer_group = Stream.group(renderer_feed.inputs).rename("renderer") + + self.feed = DataFeed([internal_group, external_group, renderer_group]) + else: + self.feed = DataFeed([internal_group, external_group]) + + self.window_size = window_size + self.min_periods = min_periods + + self._observation_dtype = kwargs.get("dtype", np.float32) + self._observation_lows = kwargs.get("observation_lows", -np.inf) + self._observation_highs = kwargs.get("observation_highs", np.inf) + + self.history = ObservationHistory(window_size=window_size) + + initial_obs = self.feed.next()["external"] + n_features = len(initial_obs.keys()) + + self._observation_space = Box( + low=self._observation_lows, + high=self._observation_highs, + shape=(self.window_size, n_features), + dtype=self._observation_dtype, + ) + + self.feed = self.feed.attach(portfolio) + + self.renderer_history = [] + + self.feed.reset() + self.warmup() + + @property + def observation_space(self) -> Space: + return self._observation_space + + def warmup(self) -> None: + """Warms up the data feed.""" + if self.min_periods is not None: + for _ in range(self.min_periods): + if self.has_next(): + obs_row = self.feed.next()["external"] + self.history.push(obs_row) + + def observe(self, env: "TradingEnvironment") -> np.array: + """Observes the environment. + + As a consequence of observing the `env`, a new observation is generated + from the `feed` and stored in the observation history. + + Returns + ------- + `np.array` + The current observation of the environment. + """ + data = self.feed.next() + + # Save renderer information to history + if "renderer" in data.keys(): + self.renderer_history += [data["renderer"]] + + # Push new observation to observation history + obs_row = data["external"] + self.history.push(obs_row) + + obs = self.history.observe() + obs = obs.astype(self._observation_dtype) + return obs + + def has_next(self) -> bool: + """Checks if there is another observation to be generated. + + Returns + ------- + bool + Whether there is another observation to be generated. + """ + return self.feed.has_next() + + def reset(self, random_start=0) -> None: + """Resets the observer""" + self.renderer_history = [] + self.history.reset() + self.feed.reset(random_start) + self.warmup() + + +class IntradayObserver(Observer): + """The IntradayObserver observer that is compatible with the other `default` + components. + Parameters + ---------- + portfolio : `Portfolio` + The portfolio to be used to create the internal data feed mechanism. + feed : `DataFeed` + The feed to be used to collect observations to the observation window. + renderer_feed : `DataFeed` + The feed to be used for giving information to the renderer. + stop_time : datetime.time + The time at which the episode will stop. + window_size : int + The size of the observation window. + min_periods : int + The amount of steps needed to warmup the `feed`. + randomize : bool + Whether or not to select a random episode when reset. + **kwargs : keyword arguments + Additional keyword arguments for observer creation. + Attributes + ---------- + feed : `DataFeed` + The master feed in charge of streaming the internal, external, and + renderer data feeds. + stop_time : datetime.time + The time at which the episode will stop. + window_size : int + The size of the observation window. + min_periods : int + The amount of steps needed to warmup the `feed`. + randomize : bool + Whether or not a random episode is selected when reset. + history : `ObservationHistory` + The observation history. + renderer_history : `List[dict]` + The history of the renderer data feed. + """ + + def __init__( + self, + portfolio: "Portfolio", + feed: "DataFeed" = None, + renderer_feed: "DataFeed" = None, + stop_time: "datetime.time" = dt.time(16, 0, 0), + window_size: int = 1, + min_periods: int = None, + randomize: bool = False, + **kwargs, + ) -> None: + internal_group = Stream.group(_create_internal_streams(portfolio)).rename("internal") + external_group = Stream.group(feed.inputs).rename("external") + + if renderer_feed: + renderer_group = Stream.group(renderer_feed.inputs).rename("renderer") + + self.feed = DataFeed([internal_group, external_group, renderer_group]) + else: + self.feed = DataFeed([internal_group, external_group]) + + self.stop_time = stop_time + self.window_size = window_size + self.min_periods = min_periods + self.randomize = randomize + + self._observation_dtype = kwargs.get("dtype", np.float32) + self._observation_lows = kwargs.get("observation_lows", -np.inf) + self._observation_highs = kwargs.get("observation_highs", np.inf) + + self.history = ObservationHistory(window_size=window_size) + + initial_obs = self.feed.next()["external"] + initial_obs.pop("timestamp", None) + n_features = len(initial_obs.keys()) + + self._observation_space = Box( + low=self._observation_lows, + high=self._observation_highs, + shape=(self.window_size, n_features), + dtype=self._observation_dtype, + ) + + self.feed = self.feed.attach(portfolio) + + self.renderer_history = [] + + if self.randomize: + self.num_episodes = 0 + while self.feed.has_next(): + ts = self.feed.next()["external"]["timestamp"] + if ts.time() == self.stop_time: + self.num_episodes += 1 + + self.feed.reset() + self.warmup() + + self.stop = False + + @property + def observation_space(self) -> Space: + return self._observation_space + + def warmup(self) -> None: + """Warms up the data feed.""" + if self.min_periods is not None: + for _ in range(self.min_periods): + if self.has_next(): + obs_row = self.feed.next()["external"] + obs_row.pop("timestamp", None) + self.history.push(obs_row) + + def observe(self, env: "TradingEnvironment") -> np.array: + """Observes the environment. + As a consequence of observing the `env`, a new observation is generated + from the `feed` and stored in the observation history. + Returns + ------- + `np.array` + The current observation of the environment. + """ + data = self.feed.next() + + # Save renderer information to history + if "renderer" in data.keys(): + self.renderer_history += [data["renderer"]] + + # Push new observation to observation history + obs_row = data["external"] + try: + obs_ts = obs_row.pop("timestamp") + except KeyError: + raise KeyError("Include Stream of Timestamps named 'timestamp' in feed") + self.history.push(obs_row) + + # Check if episode should be stopped + if obs_ts.time() == self.stop_time: + self.stop = True + + obs = self.history.observe() + obs = obs.astype(self._observation_dtype) + return obs + + def has_next(self) -> bool: + """Checks if there is another observation to be generated. + Returns + ------- + bool + Whether there is another observation to be generated. + """ + return self.feed.has_next() and not self.stop + + def reset(self) -> None: + """Resets the observer""" + self.renderer_history = [] + self.history.reset() + + if self.randomize or not self.feed.has_next(): + self.feed.reset() + if self.randomize: + episode_num = 0 + while episode_num < randrange(self.num_episodes): + ts = self.feed.next()["external"]["timestamp"] + if ts.time() == self.stop_time: + episode_num += 1 + + self.warmup() + + self.stop = False diff --git a/trade_flow/environments/nt_backtest/sub_training_env/renderers.py b/trade_flow/environments/nt_backtest/sub_training_env/renderers.py new file mode 100644 index 0000000..dca1f17 --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/renderers.py @@ -0,0 +1,834 @@ +import os +import sys +import logging +import importlib + +from abc import abstractmethod +from datetime import datetime +from typing import Union, Tuple +from collections import OrderedDict + +import numpy as np +import pandas as pd + +from IPython.display import display, clear_output +from pandas.plotting import register_matplotlib_converters + +from trade_flow.environments.default.oms.orders import TradeSide +from trade_flow.environments.generic import Renderer, TradingEnvironment + + +if importlib.util.find_spec("matplotlib"): + import matplotlib.pyplot as plt + + from matplotlib import style + + style.use("ggplot") + register_matplotlib_converters() + +if importlib.util.find_spec("plotly"): + import plotly.graph_objects as go + + from plotly.subplots import make_subplots + + +def _create_auto_file_name( + filename_prefix: str, ext: str, timestamp_format: str = "%Y%m%d_%H%M%S" +) -> str: + timestamp = datetime.now().strftime(timestamp_format) + filename = filename_prefix + timestamp + "." + ext + return filename + + +def _check_path(path: str, auto_create: bool = True) -> None: + if not path or os.path.exists(path): + return + + if auto_create: + os.mkdir(path) + else: + raise OSError(f"Path '{path}' not found.") + + +def _check_valid_format(valid_formats: list, save_format: str) -> None: + if save_format not in valid_formats: + raise ValueError( + "Acceptable formats are '{}'. Found '{}'".format( + "', '".join(valid_formats), save_format + ) + ) + + +class BaseRenderer(Renderer): + """The abstract base renderer to be subclassed when making a renderer + the incorporates a `Portfolio`. + """ + + def __init__(self): + super().__init__() + self._max_episodes = None + self._max_steps = None + + @staticmethod + def _create_log_entry( + episode: int = None, + max_episodes: int = None, + step: int = None, + max_steps: int = None, + date_format: str = "%Y-%m-%d %H:%M:%S", + ) -> str: + """ + Creates a log entry to be used by a renderer. + + Parameters + ---------- + episode : int + The current episode. + max_episodes : int + The maximum number of episodes that can occur. + step : int + The current step of the current episode. + max_steps : int + The maximum number of steps within an episode that can occur. + date_format : str + The format for logging the date. + + Returns + ------- + str + a log entry + """ + log_entry = f"[{datetime.now().strftime(date_format)}]" + + if episode is not None: + log_entry += f" Episode: {episode + 1}/{max_episodes if max_episodes else ''}" + + if step is not None: + log_entry += f" Step: {step}/{max_steps if max_steps else ''}" + + return log_entry + + def render(self, env: "TradingEnvironment", **kwargs): + + price_history = None + if len(env.observer.renderer_history) > 0: + price_history = pd.DataFrame(env.observer.renderer_history) + + performance = pd.DataFrame.from_dict( + env.action_scheme.portfolio.performance, orient="index" + ) + + self.render_env( + episode=kwargs.get("episode", None), + max_episodes=kwargs.get("max_episodes", None), + step=env.clock.step, + max_steps=kwargs.get("max_steps", None), + price_history=price_history, + net_worth=performance.net_worth, + performance=performance.drop(columns=["base_symbol"]), + trades=env.action_scheme.broker.trades, + ) + + @abstractmethod + def render_env( + self, + episode: int = None, + max_episodes: int = None, + step: int = None, + max_steps: int = None, + price_history: "pd.DataFrame" = None, + net_worth: "pd.Series" = None, + performance: "pd.DataFrame" = None, + trades: "OrderedDict" = None, + ) -> None: + """Renderers the current state of the environment. + + Parameters + ---------- + episode : int + The episode that the environment is being rendered for. + max_episodes : int + The maximum number of episodes that will occur. + step : int + The step of the current episode that is happening. + max_steps : int + The maximum number of steps that will occur in an episode. + price_history : `pd.DataFrame` + The history of instrument involved with the environment. The + required columns are: date, open, high, low, close, and volume. + net_worth : `pd.Series` + The history of the net worth of the `portfolio`. + performance : `pd.Series` + The history of performance of the `portfolio`. + trades : `OrderedDict` + The history of trades for the current episode. + """ + raise NotImplementedError() + + def save(self) -> None: + """Saves the rendering of the `TradingEnvironment`.""" + pass + + def reset(self) -> None: + """Resets the renderer.""" + pass + + +class EmptyRenderer(Renderer): + """A renderer that does renders nothing. + + Needed to make sure that environment can function without requiring a + renderer. + """ + + def render(self, env, **kwargs): + pass + + +class ScreenLogger(BaseRenderer): + """Logs information the screen of the user. + + Parameters + ---------- + date_format : str + The format for logging the date. + """ + + DEFAULT_FORMAT: str = "[%(asctime)-15s] %(message)s" + + def __init__(self, date_format: str = "%Y-%m-%d %H:%M:%S"): + super().__init__() + self._date_format = date_format + + def render_env( + self, + episode: int = None, + max_episodes: int = None, + step: int = None, + max_steps: int = None, + price_history: pd.DataFrame = None, + net_worth: pd.Series = None, + performance: pd.DataFrame = None, + trades: "OrderedDict" = None, + ): + print( + self._create_log_entry( + episode, max_episodes, step, max_steps, date_format=self._date_format + ) + ) + + +class FileLogger(BaseRenderer): + """Logs information to a file. + + Parameters + ---------- + filename : str + The file name of the log file. If omitted, a file name will be + created automatically. + path : str + The path to save the log files to. None to save to same script directory. + log_format : str + The log entry format as per Python logging. None for default. For + more details, refer to https://docs.python.org/3/library/logging.html + timestamp_format : str + The format of the timestamp of the log entry. Node for default. + """ + + DEFAULT_LOG_FORMAT: str = "[%(asctime)-15s] %(message)s" + DEFAULT_TIMESTAMP_FORMAT: str = "%Y-%m-%d %H:%M:%S" + + def __init__( + self, + filename: str = None, + path: str = "log", + log_format: str = None, + timestamp_format: str = None, + ) -> None: + super().__init__() + _check_path(path) + + if not filename: + filename = _create_auto_file_name("log_", "log") + + self._logger = logging.getLogger(self.id) + self._logger.setLevel(logging.INFO) + + if path: + filename = os.path.join(path, filename) + handler = logging.FileHandler(filename) + handler.setFormatter( + logging.Formatter( + log_format if log_format is not None else self.DEFAULT_LOG_FORMAT, + datefmt=( + timestamp_format + if timestamp_format is not None + else self.DEFAULT_TIMESTAMP_FORMAT + ), + ) + ) + self._logger.addHandler(handler) + + @property + def log_file(self) -> str: + """The filename information is being logged to. (str, read-only)""" + return self._logger.handlers[0].baseFilename + + def render_env( + self, + episode: int = None, + max_episodes: int = None, + step: int = None, + max_steps: int = None, + price_history: pd.DataFrame = None, + net_worth: pd.Series = None, + performance: pd.DataFrame = None, + trades: "OrderedDict" = None, + ) -> None: + log_entry = self._create_log_entry(episode, max_episodes, step, max_steps) + self._logger.info(f"{log_entry} - Performance:\n{performance}") + + +class PlotlyTradingChart(BaseRenderer): + """Trading visualization for TradeFlow using Plotly. + + Parameters + ---------- + display : bool + True to display the chart on the screen, False for not. + height : int + Chart height in pixels. Affects both display and saved file + charts. Set to None for 100% height. Default is None. + save_format : str + A format to save the chart to. Acceptable formats are + html, png, jpeg, webp, svg, pdf, eps. All the formats except for + 'html' require Orca. Default is None for no saving. + path : str + The path to save the char to if save_format is not None. The folder + will be created if not found. + filename_prefix : str + A string that precedes automatically-created file name + when charts are saved. Default 'chart_'. + timestamp_format : str + The format of the date shown in the chart title. + auto_open_html : bool + Works for save_format='html' only. True to automatically + open the saved chart HTML file in the default browser, False otherwise. + include_plotlyjs : Union[bool, str] + Whether to include/load the plotly.js library in the saved + file. 'cdn' results in a smaller file by loading the library online but + requires an Internet connect while True includes the library resulting + in much larger file sizes. False to not include the library. For more + details, refer to https://plot.ly/python-api-reference/generated/plotly.graph_objects.Figure.html + + Notes + ----- + Possible Future Enhancements: + - Saving images without using Orca. + - Limit displayed step range for the case of a large number of steps and let + the shown part of the chart slide after filling that range to keep showing + recent data as it's being added. + + References + ---------- + .. [1] https://plot.ly/python-api-reference/generated/plotly.graph_objects.Figure.html + .. [2] https://plot.ly/python/figurewidget/ + .. [3] https://plot.ly/python/subplots/ + .. [4] https://plot.ly/python/reference/#candlestick + .. [5] https://plot.ly/python/#chart-events + """ + + def __init__( + self, + display: bool = True, + height: int = None, + timestamp_format: str = "%Y-%m-%d %H:%M:%S", + save_format: str = None, + path: str = "charts", + filename_prefix: str = "chart_", + auto_open_html: bool = False, + include_plotlyjs: Union[bool, str] = "cdn", + ) -> None: + super().__init__() + self._height = height + self._timestamp_format = timestamp_format + self._save_format = save_format + self._path = path + self._filename_prefix = filename_prefix + self._include_plotlyjs = include_plotlyjs + self._auto_open_html = auto_open_html + + if self._save_format and self._path and not os.path.exists(path): + os.mkdir(path) + + self.fig = None + self._price_chart = None + self._volume_chart = None + self._performance_chart = None + self._net_worth_chart = None + self._base_annotations = None + self._last_trade_step = 0 + self._show_chart = display + + def _create_figure(self, performance_keys: dict) -> None: + fig = make_subplots( + rows=4, + cols=1, + shared_xaxes=True, + vertical_spacing=0.03, + row_heights=[0.55, 0.15, 0.15, 0.15], + ) + fig.add_trace( + go.Candlestick(name="Price", xaxis="x1", yaxis="y1", showlegend=False), row=1, col=1 + ) + fig.update_layout(xaxis_rangeslider_visible=False) + + fig.add_trace( + go.Bar(name="Volume", showlegend=False, marker={"color": "DodgerBlue"}), row=2, col=1 + ) + + for k in performance_keys: + fig.add_trace(go.Scatter(mode="lines", name=k), row=3, col=1) + + fig.add_trace( + go.Scatter(mode="lines", name="Net Worth", marker={"color": "DarkGreen"}), row=4, col=1 + ) + + fig.update_xaxes(linecolor="Grey", gridcolor="Gainsboro") + fig.update_yaxes(linecolor="Grey", gridcolor="Gainsboro") + fig.update_xaxes(title_text="Price", row=1) + fig.update_xaxes(title_text="Volume", row=2) + fig.update_xaxes(title_text="Performance", row=3) + fig.update_xaxes(title_text="Net Worth", row=4) + fig.update_xaxes(title_standoff=7, title_font=dict(size=12)) + + self.fig = go.FigureWidget(fig) + self._price_chart = self.fig.data[0] + self._volume_chart = self.fig.data[1] + self._performance_chart = self.fig.data[2] + self._net_worth_chart = self.fig.data[-1] + + self.fig.update_annotations({"font": {"size": 12}}) + self.fig.update_layout(template="plotly_white", height=self._height, margin=dict(t=50)) + self._base_annotations = self.fig.layout.annotations + + def _create_trade_annotations( + self, trades: "OrderedDict", price_history: "pd.DataFrame" + ) -> "Tuple[go.layout.Annotation]": + """Creates annotations of the new trades after the last one in the chart. + + Parameters + ---------- + trades : `OrderedDict` + The history of trades for the current episode. + price_history : `pd.DataFrame` + The price history of the current episode. + + Returns + ------- + `Tuple[go.layout.Annotation]` + A tuple of annotations used in the renderering process. + """ + annotations = [] + for trade in reversed(trades.values()): + trade = trade[0] + + tp = float(trade.price) + ts = float(trade.size) + + if trade.step <= self._last_trade_step: + break + + if trade.side.value == "buy": + color = "DarkGreen" + ay = 15 + qty = round(ts / tp, trade.quote_instrument.precision) + + text_info = dict( + step=trade.step, + datetime=price_history.iloc[trade.step - 1]["date"], + side=trade.side.value.upper(), + qty=qty, + size=ts, + quote_instrument=trade.quote_instrument, + price=tp, + base_instrument=trade.base_instrument, + type=trade.type.value.upper(), + commission=trade.commission, + ) + + elif trade.side.value == "sell": + color = "FireBrick" + ay = -15 + # qty = round(ts * tp, trade.quote_instrument.precision) + + text_info = dict( + step=trade.step, + datetime=price_history.iloc[trade.step - 1]["date"], + side=trade.side.value.upper(), + qty=ts, + size=round(ts * tp, trade.base_instrument.precision), + quote_instrument=trade.quote_instrument, + price=tp, + base_instrument=trade.base_instrument, + type=trade.type.value.upper(), + commission=trade.commission, + ) + else: + raise ValueError( + f"Valid trade side values are 'buy' and 'sell'. Found '{trade.side.value}'." + ) + + hovertext = ( + "Step {step} [{datetime}]
" + "{side} {qty} {quote_instrument} @ {price} {base_instrument} {type}
" + "Total: {size} {base_instrument} - Comm.: {commission}".format(**text_info) + ) + + annotations += [ + go.layout.Annotation( + x=trade.step - 1, + y=tp, + ax=0, + ay=ay, + xref="x1", + yref="y1", + showarrow=True, + arrowhead=2, + arrowcolor=color, + arrowwidth=4, + arrowsize=0.8, + hovertext=hovertext, + opacity=0.6, + hoverlabel=dict(bgcolor=color), + ) + ] + + if trades: + self._last_trade_step = trades[list(trades)[-1]][0].step + + return tuple(annotations) + + def render_env( + self, + episode: int = None, + max_episodes: int = None, + step: int = None, + max_steps: int = None, + price_history: pd.DataFrame = None, + net_worth: pd.Series = None, + performance: pd.DataFrame = None, + trades: "OrderedDict" = None, + ) -> None: + if price_history is None: + raise ValueError("renderers() is missing required positional argument 'price_history'.") + + if net_worth is None: + raise ValueError("renderers() is missing required positional argument 'net_worth'.") + + if performance is None: + raise ValueError("renderers() is missing required positional argument 'performance'.") + + if trades is None: + raise ValueError("renderers() is missing required positional argument 'trades'.") + + if not self.fig: + self._create_figure(performance.keys()) + + if self._show_chart: # ensure chart visibility through notebook cell reruns + display(self.fig) + + self.fig.layout.title = self._create_log_entry(episode, max_episodes, step, max_steps) + self._price_chart.update( + dict( + open=price_history["open"], + high=price_history["high"], + low=price_history["low"], + close=price_history["close"], + ) + ) + self.fig.layout.annotations += self._create_trade_annotations(trades, price_history) + + self._volume_chart.update({"y": price_history["volume"]}) + + for trace in self.fig.select_traces(row=3): + trace.update({"y": performance[trace.name]}) + + self._net_worth_chart.update({"y": net_worth}) + + if self._show_chart: + self.fig.show() + + def save(self) -> None: + """Saves the current chart to a file. + + Notes + ----- + All formats other than HTML require Orca installed and server running. + """ + if not self._save_format: + return + else: + valid_formats = ["html", "png", "jpeg", "webp", "svg", "pdf", "eps"] + _check_valid_format(valid_formats, self._save_format) + + _check_path(self._path) + + filename = _create_auto_file_name(self._filename_prefix, self._save_format) + filename = os.path.join(self._path, filename) + if self._save_format == "html": + self.fig.write_html( + file=filename, include_plotlyjs="cdn", auto_open=self._auto_open_html + ) + else: + self.fig.write_image(filename) + + def reset(self) -> None: + self._last_trade_step = 0 + if self.fig is None: + return + + self.fig.layout.annotations = self._base_annotations + clear_output(wait=True) + + +class MatplotlibTradingChart(BaseRenderer): + """Trading visualization for TradeFlow using Matplotlib + Parameters + --------- + display : bool + True to display the chart on the screen, False for not. + save_format : str + A format to save the chart to. Acceptable formats are + png, jpg, svg, pdf. + path : str + The path to save the char to if save_format is not None. The folder + will be created if not found. + filename_prefix : str + A string that precedes automatically-created file name + when charts are saved. Default 'chart_'. + """ + + def __init__( + self, + display: bool = True, + save_format: str = None, + path: str = "charts", + filename_prefix: str = "chart_", + ) -> None: + super().__init__() + self._volume_chart_height = 0.33 + + self._df = None + self.fig = None + self._price_ax = None + self._volume_ax = None + self.net_worth_ax = None + self._show_chart = display + + self._save_format = save_format + self._path = path + self._filename_prefix = filename_prefix + + if self._save_format and self._path and not os.path.exists(path): + os.mkdir(path) + + def _create_figure(self) -> None: + self.fig = plt.figure() + + self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1) + self.price_ax = plt.subplot2grid( + (6, 1), (2, 0), rowspan=8, colspan=1, sharex=self.net_worth_ax + ) + self.volume_ax = self.price_ax.twinx() + plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) + + def _render_trades(self, step_range, trades) -> None: + trades = [trade for sublist in trades.values() for trade in sublist] + + for trade in trades: + if trade.step in range(sys.maxsize)[step_range]: + date = self._df.index.values[trade.step] + close = self._df["close"].values[trade.step] + color = "green" + + if trade.side is TradeSide.SELL: + color = "red" + + self.price_ax.annotate( + " ", + (date, close), + xytext=(date, close), + size="large", + arrowprops=dict(arrowstyle="simple", facecolor=color), + ) + + def _render_volume(self, step_range, times) -> None: + self.volume_ax.clear() + + volume = np.array(self._df["volume"].values[step_range]) + + self.volume_ax.plot(times, volume, color="blue") + self.volume_ax.fill_between(times, volume, color="blue", alpha=0.5) + + self.volume_ax.set_ylim(0, max(volume) / self._volume_chart_height) + self.volume_ax.yaxis.set_ticks([]) + + def _render_price(self, step_range, times, current_step) -> None: + self.price_ax.clear() + + self.price_ax.plot(times, self._df["close"].values[step_range], color="black") + + last_time = self._df.index.values[current_step] + last_close = self._df["close"].values[current_step] + last_high = self._df["high"].values[current_step] + + self.price_ax.annotate( + "{0:.2f}".format(last_close), + (last_time, last_close), + xytext=(last_time, last_high), + bbox=dict(boxstyle="round", fc="w", ec="k", lw=1), + color="black", + fontsize="small", + ) + + ylim = self.price_ax.get_ylim() + self.price_ax.set_ylim(ylim[0] - (ylim[1] - ylim[0]) * self._volume_chart_height, ylim[1]) + + # def _render_net_worth(self, step_range, times, current_step, net_worths, benchmarks): + def _render_net_worth(self, step_range, times, current_step, net_worths) -> None: + self.net_worth_ax.clear() + self.net_worth_ax.plot(times, net_worths[step_range], label="Net Worth", color="g") + self.net_worth_ax.legend() + + legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={"size": 8}) + legend.get_frame().set_alpha(0.4) + + last_time = times[-1] + last_net_worth = list(net_worths[step_range])[-1] + + self.net_worth_ax.annotate( + "{0:.2f}".format(last_net_worth), + (last_time, last_net_worth), + xytext=(last_time, last_net_worth), + bbox=dict(boxstyle="round", fc="w", ec="k", lw=1), + color="black", + fontsize="small", + ) + + self.net_worth_ax.set_ylim(min(net_worths) / 1.25, max(net_worths) * 1.25) + + def render_env( + self, + episode: int = None, + max_episodes: int = None, + step: int = None, + max_steps: int = None, + price_history: "pd.DataFrame" = None, + net_worth: "pd.Series" = None, + performance: "pd.DataFrame" = None, + trades: "OrderedDict" = None, + ) -> None: + if price_history is None: + raise ValueError("renderers() is missing required positional argument 'price_history'.") + + if net_worth is None: + raise ValueError("renderers() is missing required positional argument 'net_worth'.") + + if performance is None: + raise ValueError("renderers() is missing required positional argument 'performance'.") + + if trades is None: + raise ValueError("renderers() is missing required positional argument 'trades'.") + + if not self.fig: + self._create_figure() + + if self._show_chart: + plt.show(block=False) + + current_step = step - 1 + + self._df = price_history + if max_steps: + window_size = max_steps + else: + window_size = 20 + + current_net_worth = round(net_worth[len(net_worth) - 1], 1) + initial_net_worth = round(net_worth[0], 1) + profit_percent = round((current_net_worth - initial_net_worth) / initial_net_worth * 100, 2) + + self.fig.suptitle( + "Net worth: $" + str(current_net_worth) + " | Profit: " + str(profit_percent) + "%" + ) + + window_start = max(current_step - window_size, 0) + step_range = slice(window_start, current_step) + + times = self._df.index.values[step_range] + + if len(times) > 0: + # self._render_net_worth(step_range, times, current_step, net_worths, benchmarks) + self._render_net_worth(step_range, times, current_step, net_worth) + self._render_price(step_range, times, current_step) + self._render_volume(step_range, times) + self._render_trades(step_range, trades) + + self.price_ax.set_xticklabels(times, rotation=45, horizontalalignment="right") + + plt.setp(self.net_worth_ax.get_xticklabels(), visible=False) + plt.pause(0.001) + + def save(self) -> None: + """Saves the rendering of the `TradingEnvironment`.""" + if not self._save_format: + return + else: + valid_formats = ["png", "jpeg", "svg", "pdf"] + _check_valid_format(valid_formats, self._save_format) + + _check_path(self._path) + filename = _create_auto_file_name(self._filename_prefix, self._save_format) + filename = os.path.join(self._path, filename) + self.fig.savefig(filename, format=self._save_format) + + def reset(self) -> None: + """Resets the renderer.""" + self.fig = None + self._price_ax = None + self._volume_ax = None + self.net_worth_ax = None + self._df = None + + +_registry = { + "screen-log": ScreenLogger, + "file-log": FileLogger, + "plotly": PlotlyTradingChart, + "matplot": MatplotlibTradingChart, +} + + +def get(identifier: str) -> "BaseRenderer": + """Gets the `BaseRenderer` that matches the identifier. + + Parameters + ---------- + identifier : str + The identifier for the `BaseRenderer` + + Returns + ------- + `BaseRenderer` + The renderer associated with the `identifier`. + + Raises + ------ + KeyError: + Raised if identifier is not associated with any `BaseRenderer` + """ + if identifier not in _registry.keys(): + msg = f"Identifier {identifier} is not associated with any `BaseRenderer`." + raise KeyError(msg) + return _registry[identifier]() diff --git a/trade_flow/environments/nt_backtest/sub_training_env/rewards.py b/trade_flow/environments/nt_backtest/sub_training_env/rewards.py new file mode 100644 index 0000000..ef4c772 --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/rewards.py @@ -0,0 +1,245 @@ +from abc import abstractmethod + +import numpy as np +import pandas as pd + +from trade_flow.environments.generic import RewardScheme, TradingEnvironment +from trade_flow.feed import Stream, DataFeed + + +class TradeFlowRewardScheme(RewardScheme): + """An abstract base class for reward schemes for the default environment.""" + + def reward(self, env: "TradingEnvironment") -> float: + return self.get_reward(env.action_scheme.portfolio) + + @abstractmethod + def get_reward(self, portfolio) -> float: + """Gets the reward associated with current step of the episode. + + Parameters + ---------- + portfolio : `Portfolio` + The portfolio associated with the `TradeFlowActionScheme`. + + Returns + ------- + float + The reward for the current step of the episode. + """ + raise NotImplementedError() + + +class SimpleProfit(TradeFlowRewardScheme): + """A simple reward scheme that rewards the agent for incremental increases + in net worth. + + Parameters + ---------- + window_size : int + The size of the look back window for computing the reward. + + Attributes + ---------- + window_size : int + The size of the look back window for computing the reward. + """ + + def __init__(self, window_size: int = 1): + self._window_size = self.default("window_size", window_size) + + def get_reward(self, portfolio: "Portfolio") -> float: + """Rewards the agent for incremental increases in net worth over a + sliding window. + + Parameters + ---------- + portfolio : `Portfolio` + The portfolio being used by the environment. + + Returns + ------- + float + The cumulative percentage change in net worth over the previous + `window_size` time steps. + """ + net_worths = [nw["net_worth"] for nw in portfolio.performance.values()] + if len(net_worths) > 1: + return net_worths[-1] / net_worths[-min(len(net_worths), self._window_size + 1)] - 1.0 + else: + return 0.0 + + +class RiskAdjustedReturns(TradeFlowRewardScheme): + """A reward scheme that rewards the agent for increasing its net worth, + while penalizing more volatile strategies. + + Parameters + ---------- + return_algorithm : {'sharpe', 'sortino'}, Default 'sharpe'. + The risk-adjusted return metric to use. + risk_free_rate : float, Default 0. + The risk free rate of returns to use for calculating metrics. + target_returns : float, Default 0 + The target returns per period for use in calculating the sortino ratio. + window_size : int + The size of the look back window for computing the reward. + """ + + def __init__( + self, + return_algorithm: str = "sharpe", + risk_free_rate: float = 0.0, + target_returns: float = 0.0, + window_size: int = 1, + ) -> None: + algorithm = self.default("return_algorithm", return_algorithm) + + assert algorithm in ["sharpe", "sortino"] + + if algorithm == "sharpe": + return_algorithm = self._sharpe_ratio + elif algorithm == "sortino": + return_algorithm = self._sortino_ratio + + self._return_algorithm = return_algorithm + self._risk_free_rate = self.default("risk_free_rate", risk_free_rate) + self._target_returns = self.default("target_returns", target_returns) + self._window_size = self.default("window_size", window_size) + + def _sharpe_ratio(self, returns: "pd.Series") -> float: + """Computes the sharpe ratio for a given series of a returns. + + Parameters + ---------- + returns : `pd.Series` + The returns for the `portfolio`. + + Returns + ------- + float + The sharpe ratio for the given series of a `returns`. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Sharpe_ratio + """ + return (np.mean(returns) - self._risk_free_rate + 1e-9) / (np.std(returns) + 1e-9) + + def _sortino_ratio(self, returns: "pd.Series") -> float: + """Computes the sortino ratio for a given series of a returns. + + Parameters + ---------- + returns : `pd.Series` + The returns for the `portfolio`. + + Returns + ------- + float + The sortino ratio for the given series of a `returns`. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Sortino_ratio + """ + downside_returns = returns.copy() + downside_returns[returns < self._target_returns] = returns**2 + + expected_return = np.mean(returns) + downside_std = np.sqrt(np.std(downside_returns)) + + return (expected_return - self._risk_free_rate + 1e-9) / (downside_std + 1e-9) + + def get_reward(self, portfolio: "Portfolio") -> float: + """Computes the reward corresponding to the selected risk-adjusted return metric. + + Parameters + ---------- + portfolio : `Portfolio` + The current portfolio being used by the environment. + + Returns + ------- + float + The reward corresponding to the selected risk-adjusted return metric. + """ + net_worths = [nw["net_worth"] for nw in portfolio.performance.values()][ + -(self._window_size + 1) : + ] + returns = pd.Series(net_worths).pct_change().dropna() + risk_adjusted_return = self._return_algorithm(returns) + return risk_adjusted_return + + +class PBR(TradeFlowRewardScheme): + """A reward scheme for position-based returns. + + * Let :math:`p_t` denote the price at time t. + * Let :math:`x_t` denote the position at time t. + * Let :math:`R_t` denote the reward at time t. + + Then the reward is defined as, + :math:`R_{t} = (p_{t} - p_{t-1}) \cdot x_{t}`. + + Parameters + ---------- + price : `Stream` + The price stream to use for computing rewards. + """ + + registered_name = "pbr" + + def __init__(self, price: "Stream") -> None: + super().__init__() + self.position = -1 + + r = Stream.sensor(price, lambda p: p.value, dtype="float").diff() + position = Stream.sensor(self, lambda rs: rs.position, dtype="float") + + reward = (position * r).fillna(0).rename("reward") + + self.feed = DataFeed([reward]) + self.feed.compile() + + def on_action(self, action: int) -> None: + self.position = -1 if action == 0 else 1 + + def get_reward(self, portfolio: "Portfolio") -> float: + return self.feed.next()["reward"] + + def reset(self) -> None: + """Resets the `position` and `feed` of the reward scheme.""" + self.position = -1 + self.feed.reset() + + +_registry = { + "simple": SimpleProfit, + "risk-adjusted": RiskAdjustedReturns, + "pbr": PBR, +} + + +def get(identifier: str) -> "TradeFlowRewardScheme": + """Gets the `RewardScheme` that matches with the identifier. + + Parameters + ---------- + identifier : str + The identifier for the `RewardScheme` + + Returns + ------- + `TradeFlowRewardScheme` + The reward scheme associated with the `identifier`. + + Raises + ------ + KeyError: + Raised if identifier is not associated with any `RewardScheme` + """ + if identifier not in _registry.keys(): + msg = f"Identifier {identifier} is not associated with any `RewardScheme`." + raise KeyError(msg) + return _registry[identifier]() diff --git a/trade_flow/environments/nt_backtest/sub_training_env/stoppers.py b/trade_flow/environments/nt_backtest/sub_training_env/stoppers.py new file mode 100644 index 0000000..c17bf3a --- /dev/null +++ b/trade_flow/environments/nt_backtest/sub_training_env/stoppers.py @@ -0,0 +1,32 @@ +from trade_flow.environments.generic import Stopper, TradingEnvironment + + +class MaxLossStopper(Stopper): + """A stopper that stops an episode if the portfolio has lost a particular + percentage of its wealth. + + Parameters + ---------- + max_allowed_loss : float + The maximum percentage of initial funds that is willing to + be lost before stopping the episode. + + Attributes + ---------- + max_allowed_loss : float + The maximum percentage of initial funds that is willing to + be lost before stopping the episode. + + Notes + ----- + This stopper also stops if it has reached the end of the observation feed. + """ + + def __init__(self, max_allowed_loss: float): + super().__init__() + self.max_allowed_loss = max_allowed_loss + + def stop(self, env: "TradingEnvironment") -> bool: + c1 = env.action_scheme.portfolio.profit_loss > self.max_allowed_loss + c2 = not env.observer.has_next() + return c1 or c2 diff --git a/trade_flow/environments/nt_backtest/utils.py b/trade_flow/environments/nt_backtest/utils.py index 0f38ad0..5b514ea 100644 --- a/trade_flow/environments/nt_backtest/utils.py +++ b/trade_flow/environments/nt_backtest/utils.py @@ -6,6 +6,9 @@ from nautilus_trader.model.enums import AggregationSource from nautilus_trader.model.identifiers import InstrumentId +from trade_flow.environments.default.oms.portfolio import Portfolio +from trade_flow.environments.generic.environment import TradingEnvironment + def make_bar_type(instrument_id: InstrumentId, bar_spec) -> BarType: return BarType( @@ -24,7 +27,7 @@ def one(iterable): return iterable[0] -def bars_to_dataframe( +def pair_bars_to_dataframe( source_id: str, source_bars: List[Bar], target_id: str, target_bars: List[Bar] ) -> pd.DataFrame: def _bars_to_frame(bars, instrument_id): @@ -37,6 +40,16 @@ def _bars_to_frame(bars, instrument_id): return data.dropna() +def bars_to_dataframe(instrument_id: str, instrument_bars: List[Bar]) -> pd.DataFrame: + def _bars_to_frame(bars, instrument_id): + df = pd.DataFrame([t.to_dict(t) for t in bars]).astype({"close": float}) + return df.assign(instrument_id=instrument_id).set_index(["instrument_id", "ts_init"]) + + instrument_df = _bars_to_frame(bars=instrument_bars, instrument_id=instrument_id) + data = pd.concat([instrument_df])["close"].unstack(0).sort_index().fillna(method="ffill") + return data.dropna() + + def human_readable_duration(ns: float): from dateutil.relativedelta import relativedelta # type: ignore From a672ac3d97af1027aeef2e98bc33e4fc99daa2b8 Mon Sep 17 00:00:00 2001 From: Ojietohamen Samuel Date: Tue, 8 Oct 2024 17:07:26 +0100 Subject: [PATCH 04/12] feat: add base application component --- infrastructure/application/.eslintrc.json | 3 + infrastructure/application/.gitignore | 36 + infrastructure/application/README.md | 252 + .../admin/components/authentication.tsx | 100 + .../app/(admin)/admin/components/contents.tsx | 75 + .../application/app/(admin)/admin/page.tsx | 17 + .../(routes)/sign-in/[[...sign-in]]/page.tsx | 5 + .../(routes)/sign-up/[[...sign-up]]/page.tsx | 5 + .../application/app/(auth)/layout.tsx | 36 + .../application/app/(landing-page)/page.tsx | 13 + .../(routes)/chat/components/BotAvatar.tsx | 15 + .../(routes)/chat/components/ChatContent.tsx | 95 + .../(routes)/chat/components/ChatMessage.tsx | 41 + .../(routes)/chat/components/ChatMessages.tsx | 59 + .../(routes)/chat/components/SideContent.tsx | 17 + .../(routes)/chat/components/UserAvatar.tsx | 25 + .../(routes)/chat/components/constants.ts | 7 + .../app/(root)/(routes)/chat/page.tsx | 26 + .../components/accordion-container.tsx | 51 + .../components/accordion/doughnut-chart.tsx | 84 + .../components/accordion/portfolio-item.tsx | 68 + .../components/accordion/watchlist.tsx | 63 + .../components/account-container.tsx | 56 + .../components/account/account-card.tsx | 61 + .../components/account/baseline-chart.tsx | 0 .../components/account/line-chart.tsx | 58 + .../dashboard/components/bank-container.tsx | 15 + .../dashboard/components/bank/add-card.tsx | 187 + .../components/bank/card-content.tsx | 40 + .../components/bank/color-selector.tsx | 44 + .../dashboard/components/bank/deposit.tsx | 241 + .../components/bank/dropdown-content.tsx | 103 + .../dashboard/components/bank/remove-card.tsx | 184 + .../dashboard/components/bank/switch-card.tsx | 33 + .../dashboard/components/bank/view-cards.tsx | 56 + .../dashboard/components/bank/withdraw.tsx | 246 + .../dashboard/components/bank/wrapper.tsx | 47 + .../(routes)/dashboard/components/index.ts | 36 + .../dashboard/components/table-container.tsx | 38 + .../components/table/column-header.tsx | 72 + .../dashboard/components/table/columns.tsx | 92 + .../table/data-table-pagination.tsx | 98 + .../dashboard/components/table/data-table.tsx | 172 + .../dashboard/components/table/greetings.tsx | 18 + .../app/(root)/(routes)/dashboard/page.tsx | 58 + .../components/company-profile/back-btn.tsx | 23 + .../company-profile/company-details.tsx | 222 + .../company-profile/company-profile.tsx | 125 + .../company-profile/employee-card.tsx | 20 + .../company-profile/executive-card.tsx | 30 + .../market/components/market-container.tsx | 85 + .../market/components/products/bar-chart.tsx | 55 + .../components/products/featured-product.tsx | 85 + .../market/components/products/heart.tsx | 114 + .../components/products/product-dialog.tsx | 48 + .../components/products/progress-bar.tsx | 73 + .../market/components/table/column-header.tsx | 71 + .../market/components/table/columns.tsx | 124 + .../market/components/table/data-table.tsx | 203 + .../market/components/table/search-input.tsx | 45 + .../market/components/table/stock-list.tsx | 198 + .../components/table/table-container.tsx | 26 + .../components/table/table-pagination.tsx | 96 + .../transaction/buy-transaction.tsx | 224 + .../transaction/sell-transaction.tsx | 229 + .../market/components/transaction/wrapper.tsx | 39 + .../app/(root)/(routes)/market/page.tsx | 27 + .../application/app/(root)/layout.tsx | 21 + .../application/app/api/account/route.ts | 33 + .../app/api/admin/authenticate/route.ts | 13 + .../application/app/api/admin/route.ts | 31 + .../application/app/api/card/add/route.ts | 27 + .../application/app/api/card/deposit/route.ts | 68 + .../application/app/api/card/remove/route.ts | 20 + .../application/app/api/card/route.ts | 21 + .../app/api/card/withdraw/route.ts | 67 + .../application/app/api/chat/route.ts | 52 + .../application/app/api/portfolio/route.ts | 30 + .../app/api/transaction/buy/route.ts | 99 + .../application/app/api/transaction/route.ts | 33 + .../app/api/transaction/sell/route.ts | 109 + .../application/app/api/watchlist/route.ts | 67 + infrastructure/application/app/favicon.ico | Bin 0 -> 3230 bytes infrastructure/application/app/globals.css | 100 + infrastructure/application/app/layout.tsx | 47 + infrastructure/application/components.json | 17 + .../components/aceternity-ui/3d-pin.tsx | 165 + .../aceternity-ui/moving-border.tsx | 139 + .../application/components/app/hero.tsx | 53 + .../components/app/home-navbar.tsx | 43 + .../components/app/image-component.tsx | 13 + .../components/app/image-fallback.tsx | 30 + .../application/components/app/navbar.tsx | 38 + .../application/components/app/sidebar.tsx | 72 + .../components/app/theme-provider.tsx | 9 + .../components/app/theme-switch.tsx | 44 + .../components/providers/providers.tsx | 15 + .../components/providers/toast-provider.tsx | 5 + .../components/shadcn-ui/accordion.tsx | 58 + .../components/shadcn-ui/button.tsx | 58 + .../application/components/shadcn-ui/card.tsx | 79 + .../components/shadcn-ui/checkbox.tsx | 30 + .../components/shadcn-ui/collapsible.tsx | 11 + .../components/shadcn-ui/dialog.tsx | 122 + .../components/shadcn-ui/dropdown-menu.tsx | 200 + .../application/components/shadcn-ui/form.tsx | 176 + .../components/shadcn-ui/input.tsx | 25 + .../components/shadcn-ui/label.tsx | 26 + .../components/shadcn-ui/pagination.tsx | 117 + .../components/shadcn-ui/progress.tsx | 28 + .../components/shadcn-ui/scroll-area.tsx | 48 + .../components/shadcn-ui/select.tsx | 160 + .../components/shadcn-ui/separator.tsx | 31 + .../components/shadcn-ui/sheet.tsx | 139 + .../components/shadcn-ui/skeleton.tsx | 15 + .../components/shadcn-ui/switch.tsx | 29 + .../components/shadcn-ui/table.tsx | 117 + .../components/shadcn-ui/toast.tsx | 127 + .../components/shadcn-ui/toaster.tsx | 44 + .../components/shadcn-ui/use-toast.ts | 192 + infrastructure/application/constants/index.ts | 38 + .../application/doc/architecture.d2 | 90 + .../application/doc/architecture.svg | 126 + infrastructure/application/doc/class.d2 | 177 + infrastructure/application/doc/class.svg | 102 + .../application/generate_prj_view.py | 42 + .../application/hooks/use-account.ts | 11 + .../application/hooks/use-animation.ts | 15 + infrastructure/application/hooks/use-color.ts | 11 + infrastructure/application/hooks/use-order.ts | 15 + .../application/hooks/use-ticker.ts | 15 + infrastructure/application/middleware.ts | 12 + infrastructure/application/next.config.js | 20 + infrastructure/application/package-lock.json | 10526 ++++++++++++++++ infrastructure/application/package.json | 98 + .../application/prisma/schema.prisma | 134 + .../application/public/avatars/ava1.png | Bin 0 -> 14448 bytes .../application/public/avatars/ava10.png | Bin 0 -> 15833 bytes .../application/public/avatars/ava2.png | Bin 0 -> 16343 bytes .../application/public/avatars/ava3.png | Bin 0 -> 16242 bytes .../application/public/avatars/ava4.png | Bin 0 -> 15262 bytes .../application/public/avatars/ava5.png | Bin 0 -> 16493 bytes .../application/public/avatars/ava6.png | Bin 0 -> 14962 bytes .../application/public/avatars/ava7.png | Bin 0 -> 15784 bytes .../application/public/avatars/ava8.png | Bin 0 -> 17683 bytes .../application/public/avatars/ava9.png | Bin 0 -> 14706 bytes .../public/employees/boeing-ceo.jpg | Bin 0 -> 88107 bytes .../application/public/employees/ceo-fb.webp | Bin 0 -> 11114 bytes .../public/employees/contract.webp | Bin 0 -> 9692 bytes .../application/public/employees/jnj-ceo.webp | Bin 0 -> 670286 bytes .../application/public/employees/lly-ceo.webp | Bin 0 -> 75356 bytes .../application/public/employees/mrk-ceo.webp | Bin 0 -> 7604558 bytes .../public/employees/nvda-ceo.webp | Bin 0 -> 10786 bytes .../public/employees/orcl-ceo.webp | Bin 0 -> 2276110 bytes .../application/public/employees/tsla-ceo.png | Bin 0 -> 393184 bytes .../public/landing-page/landing-animation.gif | Bin 0 -> 874144 bytes .../public/landing-page/landing-page.png | Bin 0 -> 151892 bytes .../landing-page/landingpage-animation2.gif | Bin 0 -> 3914839 bytes .../application/public/landing-page/logo.webp | Bin 0 -> 26744 bytes .../application/public/landing-page/logo2.png | Bin 0 -> 40242 bytes .../application/public/logos/aapl.svg | 2 + .../application/public/logos/abnb.svg | 7 + .../application/public/logos/amzn.svg | 2 + .../application/public/logos/ba.svg | 1 + .../application/public/logos/cisco.svg | 1 + .../application/public/logos/cola.svg | 9 + .../application/public/logos/dummy-logo.webp | Bin 0 -> 36874 bytes .../public/logos/dynamitetrade.webp | Bin 0 -> 6808 bytes .../application/public/logos/ebay.svg | 1 + .../application/public/logos/googl.svg | 2 + .../application/public/logos/ibm.svg | 1 + .../application/public/logos/jd.svg | 1 + .../application/public/logos/ma.svg | 16 + .../application/public/logos/meta.svg | 1 + .../application/public/logos/msft.svg | 1 + .../application/public/logos/nike.svg | 1 + .../application/public/logos/nvda.svg | 2 + .../application/public/logos/pfizer.svg | 1 + .../application/public/logos/pypl.svg | 15 + .../application/public/logos/qcom.svg | 1 + .../application/public/logos/salesforce.svg | 1 + .../application/public/logos/shopify.svg | 1 + .../application/public/logos/sofi.svg | 38 + .../application/public/logos/spotify.svg | 1 + .../application/public/logos/tsla.svg | 1 + .../application/public/logos/uber.svg | 1 + .../application/public/logos/ups.svg | 1 + .../application/public/logos/wmt.svg | 1 + .../application/public/logos/xom.svg | 1 + .../application/public/products/aapl.webp | Bin 0 -> 5108 bytes .../application/public/products/abnb.webp | Bin 0 -> 69952 bytes .../application/public/products/amd.webp | Bin 0 -> 14338 bytes .../application/public/products/amzn.webp | Bin 0 -> 23442 bytes .../application/public/products/ba.webp | Bin 0 -> 22140 bytes .../application/public/products/cola.webp | Bin 0 -> 32534 bytes .../public/products/dummy-product.webp | Bin 0 -> 169574 bytes .../application/public/products/ebay.webp | Bin 0 -> 16096 bytes .../application/public/products/googl.webp | Bin 0 -> 15290 bytes .../application/public/products/ibm.webp | Bin 0 -> 33448 bytes .../application/public/products/jd.webp | Bin 0 -> 49800 bytes .../application/public/products/ma.webp | Bin 0 -> 5920 bytes .../application/public/products/meta.webp | Bin 0 -> 23462 bytes .../application/public/products/msft.webp | Bin 0 -> 11678 bytes .../application/public/products/nvda.webp | Bin 0 -> 50314 bytes .../application/public/products/pypl.webp | Bin 0 -> 5950 bytes .../application/public/products/sofi.webp | Bin 0 -> 31154 bytes .../application/public/products/tsla.webp | Bin 0 -> 45448 bytes .../application/public/products/visa.webp | Bin 0 -> 66888 bytes .../application/public/products/wmt.webp | Bin 0 -> 25552 bytes .../application/scripts/checkDuplicates.ts | 39 + infrastructure/application/scripts/index.ts | 16 + .../application/scripts/seedCompany.ts | 82 + .../application/scripts/seedImage.ts | 65 + .../application/scripts/seedMessage.ts | 17 + .../application/scripts/seedTicker.ts | 103 + infrastructure/application/tailwind.config.js | 101 + infrastructure/application/tailwind.config.ts | 20 + infrastructure/application/tsconfig.json | 27 + 218 files changed, 20536 insertions(+) create mode 100644 infrastructure/application/.eslintrc.json create mode 100644 infrastructure/application/.gitignore create mode 100644 infrastructure/application/README.md create mode 100644 infrastructure/application/app/(admin)/admin/components/authentication.tsx create mode 100644 infrastructure/application/app/(admin)/admin/components/contents.tsx create mode 100644 infrastructure/application/app/(admin)/admin/page.tsx create mode 100644 infrastructure/application/app/(auth)/(routes)/sign-in/[[...sign-in]]/page.tsx create mode 100644 infrastructure/application/app/(auth)/(routes)/sign-up/[[...sign-up]]/page.tsx create mode 100644 infrastructure/application/app/(auth)/layout.tsx create mode 100644 infrastructure/application/app/(landing-page)/page.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/BotAvatar.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/ChatContent.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/ChatMessage.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/ChatMessages.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/SideContent.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/UserAvatar.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/chat/components/constants.ts create mode 100644 infrastructure/application/app/(root)/(routes)/chat/page.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/accordion-container.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/doughnut-chart.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/portfolio-item.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/watchlist.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/account-container.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/account/account-card.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/account/baseline-chart.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/account/line-chart.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank-container.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/add-card.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/card-content.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/color-selector.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/deposit.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/dropdown-content.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/remove-card.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/switch-card.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/view-cards.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/withdraw.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/bank/wrapper.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/index.ts create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/table-container.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/table/column-header.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/table/columns.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table-pagination.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/components/table/greetings.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/dashboard/page.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/company-profile/back-btn.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-details.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-profile.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/company-profile/employee-card.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/company-profile/executive-card.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/market-container.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/products/bar-chart.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/products/featured-product.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/products/heart.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/products/product-dialog.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/products/progress-bar.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/column-header.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/columns.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/data-table.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/search-input.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/stock-list.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/table-container.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/table/table-pagination.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/transaction/buy-transaction.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/transaction/sell-transaction.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/components/transaction/wrapper.tsx create mode 100644 infrastructure/application/app/(root)/(routes)/market/page.tsx create mode 100644 infrastructure/application/app/(root)/layout.tsx create mode 100644 infrastructure/application/app/api/account/route.ts create mode 100644 infrastructure/application/app/api/admin/authenticate/route.ts create mode 100644 infrastructure/application/app/api/admin/route.ts create mode 100644 infrastructure/application/app/api/card/add/route.ts create mode 100644 infrastructure/application/app/api/card/deposit/route.ts create mode 100644 infrastructure/application/app/api/card/remove/route.ts create mode 100644 infrastructure/application/app/api/card/route.ts create mode 100644 infrastructure/application/app/api/card/withdraw/route.ts create mode 100644 infrastructure/application/app/api/chat/route.ts create mode 100644 infrastructure/application/app/api/portfolio/route.ts create mode 100644 infrastructure/application/app/api/transaction/buy/route.ts create mode 100644 infrastructure/application/app/api/transaction/route.ts create mode 100644 infrastructure/application/app/api/transaction/sell/route.ts create mode 100644 infrastructure/application/app/api/watchlist/route.ts create mode 100644 infrastructure/application/app/favicon.ico create mode 100644 infrastructure/application/app/globals.css create mode 100644 infrastructure/application/app/layout.tsx create mode 100644 infrastructure/application/components.json create mode 100644 infrastructure/application/components/aceternity-ui/3d-pin.tsx create mode 100644 infrastructure/application/components/aceternity-ui/moving-border.tsx create mode 100644 infrastructure/application/components/app/hero.tsx create mode 100644 infrastructure/application/components/app/home-navbar.tsx create mode 100644 infrastructure/application/components/app/image-component.tsx create mode 100644 infrastructure/application/components/app/image-fallback.tsx create mode 100644 infrastructure/application/components/app/navbar.tsx create mode 100644 infrastructure/application/components/app/sidebar.tsx create mode 100644 infrastructure/application/components/app/theme-provider.tsx create mode 100644 infrastructure/application/components/app/theme-switch.tsx create mode 100644 infrastructure/application/components/providers/providers.tsx create mode 100644 infrastructure/application/components/providers/toast-provider.tsx create mode 100644 infrastructure/application/components/shadcn-ui/accordion.tsx create mode 100644 infrastructure/application/components/shadcn-ui/button.tsx create mode 100644 infrastructure/application/components/shadcn-ui/card.tsx create mode 100644 infrastructure/application/components/shadcn-ui/checkbox.tsx create mode 100644 infrastructure/application/components/shadcn-ui/collapsible.tsx create mode 100644 infrastructure/application/components/shadcn-ui/dialog.tsx create mode 100644 infrastructure/application/components/shadcn-ui/dropdown-menu.tsx create mode 100644 infrastructure/application/components/shadcn-ui/form.tsx create mode 100644 infrastructure/application/components/shadcn-ui/input.tsx create mode 100644 infrastructure/application/components/shadcn-ui/label.tsx create mode 100644 infrastructure/application/components/shadcn-ui/pagination.tsx create mode 100644 infrastructure/application/components/shadcn-ui/progress.tsx create mode 100644 infrastructure/application/components/shadcn-ui/scroll-area.tsx create mode 100644 infrastructure/application/components/shadcn-ui/select.tsx create mode 100644 infrastructure/application/components/shadcn-ui/separator.tsx create mode 100644 infrastructure/application/components/shadcn-ui/sheet.tsx create mode 100644 infrastructure/application/components/shadcn-ui/skeleton.tsx create mode 100644 infrastructure/application/components/shadcn-ui/switch.tsx create mode 100644 infrastructure/application/components/shadcn-ui/table.tsx create mode 100644 infrastructure/application/components/shadcn-ui/toast.tsx create mode 100644 infrastructure/application/components/shadcn-ui/toaster.tsx create mode 100644 infrastructure/application/components/shadcn-ui/use-toast.ts create mode 100644 infrastructure/application/constants/index.ts create mode 100644 infrastructure/application/doc/architecture.d2 create mode 100644 infrastructure/application/doc/architecture.svg create mode 100644 infrastructure/application/doc/class.d2 create mode 100644 infrastructure/application/doc/class.svg create mode 100644 infrastructure/application/generate_prj_view.py create mode 100644 infrastructure/application/hooks/use-account.ts create mode 100644 infrastructure/application/hooks/use-animation.ts create mode 100644 infrastructure/application/hooks/use-color.ts create mode 100644 infrastructure/application/hooks/use-order.ts create mode 100644 infrastructure/application/hooks/use-ticker.ts create mode 100644 infrastructure/application/middleware.ts create mode 100644 infrastructure/application/next.config.js create mode 100644 infrastructure/application/package-lock.json create mode 100644 infrastructure/application/package.json create mode 100644 infrastructure/application/prisma/schema.prisma create mode 100644 infrastructure/application/public/avatars/ava1.png create mode 100644 infrastructure/application/public/avatars/ava10.png create mode 100644 infrastructure/application/public/avatars/ava2.png create mode 100644 infrastructure/application/public/avatars/ava3.png create mode 100644 infrastructure/application/public/avatars/ava4.png create mode 100644 infrastructure/application/public/avatars/ava5.png create mode 100644 infrastructure/application/public/avatars/ava6.png create mode 100644 infrastructure/application/public/avatars/ava7.png create mode 100644 infrastructure/application/public/avatars/ava8.png create mode 100644 infrastructure/application/public/avatars/ava9.png create mode 100644 infrastructure/application/public/employees/boeing-ceo.jpg create mode 100644 infrastructure/application/public/employees/ceo-fb.webp create mode 100644 infrastructure/application/public/employees/contract.webp create mode 100644 infrastructure/application/public/employees/jnj-ceo.webp create mode 100644 infrastructure/application/public/employees/lly-ceo.webp create mode 100644 infrastructure/application/public/employees/mrk-ceo.webp create mode 100644 infrastructure/application/public/employees/nvda-ceo.webp create mode 100644 infrastructure/application/public/employees/orcl-ceo.webp create mode 100644 infrastructure/application/public/employees/tsla-ceo.png create mode 100644 infrastructure/application/public/landing-page/landing-animation.gif create mode 100644 infrastructure/application/public/landing-page/landing-page.png create mode 100644 infrastructure/application/public/landing-page/landingpage-animation2.gif create mode 100644 infrastructure/application/public/landing-page/logo.webp create mode 100644 infrastructure/application/public/landing-page/logo2.png create mode 100644 infrastructure/application/public/logos/aapl.svg create mode 100644 infrastructure/application/public/logos/abnb.svg create mode 100644 infrastructure/application/public/logos/amzn.svg create mode 100644 infrastructure/application/public/logos/ba.svg create mode 100644 infrastructure/application/public/logos/cisco.svg create mode 100644 infrastructure/application/public/logos/cola.svg create mode 100644 infrastructure/application/public/logos/dummy-logo.webp create mode 100644 infrastructure/application/public/logos/dynamitetrade.webp create mode 100644 infrastructure/application/public/logos/ebay.svg create mode 100644 infrastructure/application/public/logos/googl.svg create mode 100644 infrastructure/application/public/logos/ibm.svg create mode 100644 infrastructure/application/public/logos/jd.svg create mode 100644 infrastructure/application/public/logos/ma.svg create mode 100644 infrastructure/application/public/logos/meta.svg create mode 100644 infrastructure/application/public/logos/msft.svg create mode 100644 infrastructure/application/public/logos/nike.svg create mode 100644 infrastructure/application/public/logos/nvda.svg create mode 100644 infrastructure/application/public/logos/pfizer.svg create mode 100644 infrastructure/application/public/logos/pypl.svg create mode 100644 infrastructure/application/public/logos/qcom.svg create mode 100644 infrastructure/application/public/logos/salesforce.svg create mode 100644 infrastructure/application/public/logos/shopify.svg create mode 100644 infrastructure/application/public/logos/sofi.svg create mode 100644 infrastructure/application/public/logos/spotify.svg create mode 100644 infrastructure/application/public/logos/tsla.svg create mode 100644 infrastructure/application/public/logos/uber.svg create mode 100644 infrastructure/application/public/logos/ups.svg create mode 100644 infrastructure/application/public/logos/wmt.svg create mode 100644 infrastructure/application/public/logos/xom.svg create mode 100644 infrastructure/application/public/products/aapl.webp create mode 100644 infrastructure/application/public/products/abnb.webp create mode 100644 infrastructure/application/public/products/amd.webp create mode 100644 infrastructure/application/public/products/amzn.webp create mode 100644 infrastructure/application/public/products/ba.webp create mode 100644 infrastructure/application/public/products/cola.webp create mode 100644 infrastructure/application/public/products/dummy-product.webp create mode 100644 infrastructure/application/public/products/ebay.webp create mode 100644 infrastructure/application/public/products/googl.webp create mode 100644 infrastructure/application/public/products/ibm.webp create mode 100644 infrastructure/application/public/products/jd.webp create mode 100644 infrastructure/application/public/products/ma.webp create mode 100644 infrastructure/application/public/products/meta.webp create mode 100644 infrastructure/application/public/products/msft.webp create mode 100644 infrastructure/application/public/products/nvda.webp create mode 100644 infrastructure/application/public/products/pypl.webp create mode 100644 infrastructure/application/public/products/sofi.webp create mode 100644 infrastructure/application/public/products/tsla.webp create mode 100644 infrastructure/application/public/products/visa.webp create mode 100644 infrastructure/application/public/products/wmt.webp create mode 100644 infrastructure/application/scripts/checkDuplicates.ts create mode 100644 infrastructure/application/scripts/index.ts create mode 100644 infrastructure/application/scripts/seedCompany.ts create mode 100644 infrastructure/application/scripts/seedImage.ts create mode 100644 infrastructure/application/scripts/seedMessage.ts create mode 100644 infrastructure/application/scripts/seedTicker.ts create mode 100644 infrastructure/application/tailwind.config.js create mode 100644 infrastructure/application/tailwind.config.ts create mode 100644 infrastructure/application/tsconfig.json diff --git a/infrastructure/application/.eslintrc.json b/infrastructure/application/.eslintrc.json new file mode 100644 index 0000000..bffb357 --- /dev/null +++ b/infrastructure/application/.eslintrc.json @@ -0,0 +1,3 @@ +{ + "extends": "next/core-web-vitals" +} diff --git a/infrastructure/application/.gitignore b/infrastructure/application/.gitignore new file mode 100644 index 0000000..7e1e047 --- /dev/null +++ b/infrastructure/application/.gitignore @@ -0,0 +1,36 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js +.yarn/install-state.gz + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env*.local +.env +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts diff --git a/infrastructure/application/README.md b/infrastructure/application/README.md new file mode 100644 index 0000000..d8e934b --- /dev/null +++ b/infrastructure/application/README.md @@ -0,0 +1,252 @@ +# :point_right: Introduction + +DynamiteTrade is an opensource financial project aiming to provide a clean and snappy user interface to view stock information of companies in (soft) real time. +The project is written in JavaScript/TypeScript leveraging full stack framework NextJS and deploying on Vercel hosting platform. + +Here is the site: [https://dynamitetrade.vercel.app/](https://dynamitetrade.vercel.app/) + + + +## :zap: Tech Stack + +
+Client + +* Typescript +* Tailwind CSS +* Clerk +* Shadcn-UI +* Aceternity-UI +* Chart.JS + +
+ +
+Server + +* Typescript +* NextJS +* Tanstack Query +* ChatGPT-3,5 Turbo + +
+ +
+Database + +* Postgresql +* Supabass +* Prisma + +
+ +
+Hosting + +* Vercel + +
+ +## :pushpin: Features + +- ChatGPT-3.5 Bot +- Light/dark mode toggle +- Buy/sell stocks +- Add/remove stocks from watchlist +- View transaction records +- Add/remove customizable cards +- Deposit/withdraw from trading account +- Authentication +- Search stocks +- View stock prices & company history + +## :key: Environment Variables + +To run this project, you will need to add the following environment variables to your .env file + +`YAHOO_FINANCE_STOCK_SUMMARY` + +`NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY` + +`CLERK_SECRET_KEY` + +`NEXT_PUBLIC_CLERK_SIGN_IN_URL` + +`NEXT_PUBLIC_CLERK_SIGN_UP_URL` + +`NEXT_PUBLIC_CLERK_AFTER_SIGN_IN_URL` + +`NEXT_PUBLIC_CLERK_AFTER_SIGN_UP_URL` + +`DATABASE_URL` + +`DIRECT_URL` + +# :hammer: Build + +This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). + +## Getting Started + +First, run the development server: + +```bash +npm run dev +# or +yarn dev +# or +pnpm dev +# or +bun dev +``` + +Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. + +You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. + +This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font. + +## Learn More + +To learn more about Next.js, take a look at the following resources: + +- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. +- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. + +## Deploy on Vercel + +The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js. + +Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details. + +## Project Layout + +``` +└──./ + ├──app/ + │ ├──(auth)/ + │ │ ├──(routes)/ + │ │ │ ├──sign-in/ + │ │ │ │ └──[[...sign-in]]/ + │ │ │ └──sign-up/ + │ │ │ └──[[...sign-up]]/ + │ ├──(landing-page)/ + │ ├──(root)/ + │ │ ├──(routes)/ + │ │ │ ├──chat/ + │ │ │ │ ├──components/ + │ │ │ ├──dashboard/ + │ │ │ │ ├──components/ + │ │ │ │ │ ├──accordion/ + │ │ │ │ │ ├──account/ + │ │ │ │ │ ├──bank/ + │ │ │ │ │ ├──not-use/ + │ │ │ │ │ ├──table/ + │ │ │ └──market/ + │ │ │ ├──components/ + │ │ │ │ ├──company-profile/ + │ │ │ │ ├──company-table/ + │ │ │ │ ├──products/ + │ │ │ │ └──transaction/ + │ ├──api/ + │ │ ├──account/ + │ │ ├──chat/ + │ │ ├──dashboard-watchlist/ + │ │ ├──market-watchlist/ + │ │ └──portfolio/ + ├──components/ + │ ├──aceternity-ui/ + │ ├──shadcn-ui/ + │ └──ui/ + ├──constants/ + ├──hooks/ + ├──lib/ + ├──prisma/ + ├──providers/ + ├──public/ + │ ├──avatars/ + │ ├──employees/ + │ ├──logos/ + │ ├──products/ + ├──scripts/ + ├──types/ + └──utils/ +``` + + +## :green_book: API Reference + +#### Get account + +```http + GET /api/account +``` + +| Parameter | Type | Description | +| :-------- | :------- | :------------------------- | +| `N/A` | `Object` | **Get** user account from the database | + +#### Get portfolio + +```http + GET /api/portfolio +``` + +| Parameter | Type | Description | +| :-------- | :------- | :------------------------- | +| `N/A` | `Object` | **Get** user portfolio from the database | + +#### Update portfolio + +```http + PATCH /api/portfolio +``` + +| Parameter | Type | Description | +| :-------- | :------- | :-------------------------------- | +| `N/A` | `void` | **Update** the value of user's portfolio| + +#### Get transaction + +```http + GET /api/transaction +``` + +| Parameter | Type | Description | +| :-------- | :------- | :------------------------- | +| `N/A` | `Object[]` | **Get** a list of transactions from the database | + +#### Buy stock + +```http + POST /api/transaction/buy +``` + +| Parameter | Type | Description | +| :-------- | :------- | :------------------------- | +| `N/A` | `void` | **Execute** the 'buy' function | + +#### Sell stock + +```http + POST /api/transaction/sell +``` + +| Parameter | Type | Description | +| :-------- | :------- | :------------------------- | +| `N/A` | `void` | **Execute** the 'sell' function | + + + +## ER Diagram + + + +## Architecture + + + +## Classes + + + diff --git a/infrastructure/application/app/(admin)/admin/components/authentication.tsx b/infrastructure/application/app/(admin)/admin/components/authentication.tsx new file mode 100644 index 0000000..ca25e31 --- /dev/null +++ b/infrastructure/application/app/(admin)/admin/components/authentication.tsx @@ -0,0 +1,100 @@ +"use client" + +import { zodResolver } from "@hookform/resolvers/zod" +import { useForm } from "react-hook-form" +import { z } from "zod" +import { useQueryClient, useMutation } from "@tanstack/react-query" +import { Button } from "@/components/shadcn-ui/button" +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/shadcn-ui/form" +import { Input } from "@/components/shadcn-ui/input" +import axios from "axios" +import { useToast } from "@/components/shadcn-ui/use-toast" +import toast from "react-hot-toast" +import { useRouter } from "next/navigation" + +const formSchema = z.object({ + username: z.string().min(2, { + message: "Username must be at least 2 characters.", + }), +}) + +interface AuthenticationProps { + isSuccess: boolean, + setIsSuccess: (isSuccess: boolean) => void +} + +export function Authentication({ isSuccess, setIsSuccess }: AuthenticationProps) { + const { toast: shadcnToast } = useToast() + const router = useRouter() + const onSubmit = async (values: z.infer) => { + try { + authenticate(values) + // clear user input + } catch (error) { + console.log("Error submitting the form") + } finally { + //router.refresh() + } + } + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + username: "", + }, + }) + + const queryClient = useQueryClient(); + const { mutate: authenticate, isPending } = useMutation({ + mutationFn: (values: z.infer) => { + return axios.post("/api/admin/authenticate", { values }) + }, + onError: (error) => { + shadcnToast({ + variant: "destructive", + title: "Wrong key.", + }) + console.log(error) + }, + onSuccess: () => { + toast.success("Login Successfully") + setIsSuccess(!isSuccess); + } + }) + + + return ( +
+ +
+ + ( + + Admin Code + + + + + This is your public display name. + + + + )} + /> + + + +
+ ) +} diff --git a/infrastructure/application/app/(admin)/admin/components/contents.tsx b/infrastructure/application/app/(admin)/admin/components/contents.tsx new file mode 100644 index 0000000..5cd6882 --- /dev/null +++ b/infrastructure/application/app/(admin)/admin/components/contents.tsx @@ -0,0 +1,75 @@ +import prismadb from "@/lib/prismadb" +import { Account } from "@prisma/client" +import { useQuery } from "@tanstack/react-query" +import axios from "axios" +import Link from "next/link" + +export const Contents = () => { + const { data, isLoading } = useQuery({ + queryKey: ['getAllData'], + queryFn: async () => { + const response = await axios.get("/api/admin") + return response.data + } + }) + console.log(data) + if (isLoading) { + return
loading...
+ } + return ( +
+ + Home + +

Number of active users: {data.length}

+
+

Users

+
+ {data.map((item: Account) => ( +
+
+

UserId

+

{item.id}

+
+
+

Account balance

+

${item.accountBalance}

+
+
+

Account value

+

${item.accountValue}

+
+
+ ))} +
+
+ + {data.map((item: Account, index: number) => ( +
+

User {item.id}

+

Portfolio: {item.portfolio.companies.length} stocks

+
+ {item.portfolio.companies.map((item2: Account) => ( +
+
+

Symbol

+

{item2.symbol}

+
+
+

Price

+

${item2.price}

+
+ +
+ ))} +
+
+ ))} + + +
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(admin)/admin/page.tsx b/infrastructure/application/app/(admin)/admin/page.tsx new file mode 100644 index 0000000..e427fed --- /dev/null +++ b/infrastructure/application/app/(admin)/admin/page.tsx @@ -0,0 +1,17 @@ +"use client" +import { Authentication } from "./components/authentication" +import { useState } from "react" +import { Contents } from "./components/contents" +const AdminPage = () => { + const [isSuccess, setIsSuccess] = useState(false) + return ( + <> + {!isSuccess ? + : + + } + + ) +} + +export default AdminPage \ No newline at end of file diff --git a/infrastructure/application/app/(auth)/(routes)/sign-in/[[...sign-in]]/page.tsx b/infrastructure/application/app/(auth)/(routes)/sign-in/[[...sign-in]]/page.tsx new file mode 100644 index 0000000..5ee3fe3 --- /dev/null +++ b/infrastructure/application/app/(auth)/(routes)/sign-in/[[...sign-in]]/page.tsx @@ -0,0 +1,5 @@ +import { SignIn } from "@clerk/nextjs"; + +export default function Page() { + return ; +} \ No newline at end of file diff --git a/infrastructure/application/app/(auth)/(routes)/sign-up/[[...sign-up]]/page.tsx b/infrastructure/application/app/(auth)/(routes)/sign-up/[[...sign-up]]/page.tsx new file mode 100644 index 0000000..2743945 --- /dev/null +++ b/infrastructure/application/app/(auth)/(routes)/sign-up/[[...sign-up]]/page.tsx @@ -0,0 +1,5 @@ +import { SignUp } from "@clerk/nextjs"; + +export default function Page() { + return ; +} diff --git a/infrastructure/application/app/(auth)/layout.tsx b/infrastructure/application/app/(auth)/layout.tsx new file mode 100644 index 0000000..b49ba0f --- /dev/null +++ b/infrastructure/application/app/(auth)/layout.tsx @@ -0,0 +1,36 @@ +import Image from 'next/image' +import Link from 'next/link' +import React from 'react' +//This is a layout ui page when open sign-in/sign-up page +const layout = ({ children }: { children: React.ReactNode }) => { + return ( +
+
+ logo +
+

DynamiteTrade

+

Trade Explosively like never before

+
+
+ +
+
+ {children} + + Sign in as Admin + +
+
+
+ ) +} + +export default layout \ No newline at end of file diff --git a/infrastructure/application/app/(landing-page)/page.tsx b/infrastructure/application/app/(landing-page)/page.tsx new file mode 100644 index 0000000..670dbf1 --- /dev/null +++ b/infrastructure/application/app/(landing-page)/page.tsx @@ -0,0 +1,13 @@ +import Hero from "@/components/app/hero"; +import HomeNavbar from "@/components/app/home-navbar"; + +export default function Home() { + + return ( +
+ + +
+ ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/BotAvatar.tsx b/infrastructure/application/app/(root)/(routes)/chat/components/BotAvatar.tsx new file mode 100644 index 0000000..6145d77 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/BotAvatar.tsx @@ -0,0 +1,15 @@ +import Image from "next/image" +const BotAvatar = () => { + + return ( + bot-image + ) +} + +export default BotAvatar \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/ChatContent.tsx b/infrastructure/application/app/(root)/(routes)/chat/components/ChatContent.tsx new file mode 100644 index 0000000..bf9f7af --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/ChatContent.tsx @@ -0,0 +1,95 @@ +"use client" + +import * as z from "zod"; +import { ChatBot, Message } from "@prisma/client" +import { useState } from "react" +import { useRouter } from "next/navigation" +import ChatMessages from "./ChatMessages" +import { Form, FormControl, FormField, FormItem } from "@/components/shadcn-ui/form"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { Input } from "@/components/shadcn-ui/input"; +import { formSchema } from "./constants"; +import { ArrowRight } from "lucide-react"; +import axios from "axios" +import { ChatCompletionRequestMessage } from "openai"; + +interface ChatContentProps { + chatBot: ChatBot & { + messages: Message[] + }, +} + +export const ChatContent = ({ chatBot }: ChatContentProps) => { + // prompt: user's input will be used later + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + prompt: "" + } + }); + + const isLoading = form.formState.isSubmitting; + const router = useRouter() + + // array of objects: [{role:.., prompt:...},object2] + const [messages, setMessages] = useState([]) + + const onSubmit = async (values: z.infer) => { + try { + // User's current message + const userMessage: ChatCompletionRequestMessage = { + role: "user", + content: values.prompt + }; + // Append user message + setMessages((current) => [...current, userMessage]) + const response = await axios.post(`/api/chat`, { + messages: userMessage + }); + + setMessages((current) => [...current, response.data]) + + // clear user input + form.reset() + } catch (error) { + console.log("Error submitting the form") + } finally { + router.refresh() + } + } + + + return ( +
+ +
+ + ( + + + + + + )} + /> + + + +
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/ChatMessage.tsx b/infrastructure/application/app/(root)/(routes)/chat/components/ChatMessage.tsx new file mode 100644 index 0000000..7336aad --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/ChatMessage.tsx @@ -0,0 +1,41 @@ +import { cn } from "@/lib/utils"; +import Image from "next/image"; +import { ChatCompletionRequestMessageRoleEnum } from "openai"; +import { BeatLoader } from "react-spinners" +import { UserAvatar } from "./UserAvatar"; +import BotAvatar from "./BotAvatar"; + +export interface ChatMessageProps { + role: ChatCompletionRequestMessageRoleEnum; + content?: string; + isLoading?: boolean; + src?: string +} + +const ChatMessage = ({ + role, + content, + isLoading, + src +}: ChatMessageProps) => { + return ( +
+ {role !== "user" && + + } +
+ {isLoading + ? + : content} +
+ {role === "user" && + + } +
+ ) +} + +export default ChatMessage \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/ChatMessages.tsx b/infrastructure/application/app/(root)/(routes)/chat/components/ChatMessages.tsx new file mode 100644 index 0000000..1241b9b --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/ChatMessages.tsx @@ -0,0 +1,59 @@ +"use client" + +import { ElementRef, useEffect, useRef, useState } from "react" +import ChatMessage, { ChatMessageProps } from "./ChatMessage" +import { ChatCompletionRequestMessage } from "openai"; + +interface ChatMessagesProps { + messages: ChatCompletionRequestMessage[] + isLoading: boolean +} + +/** + * This component renders a sequence of chat messages exchanged by the chatbot and the user. + */ +const ChatMessages = ({ messages, isLoading }: ChatMessagesProps) => { + const scrollRef = useRef>(null) + + const [fakeLoading, setFakeLoading] = useState(messages.length === 0 ? true : false) + + // make fake Loading last about 2 secs + useEffect(() => { + const timeout = setTimeout(() => { + setFakeLoading(false) + }, 2000) + return () => { + clearTimeout(timeout) + } + }, []) + + useEffect(() => ( + scrollRef?.current?.scrollIntoView({ behavior: "smooth" }) || undefined + ), [messages.length]) + return ( +
+ + {messages.map((item) => ( + + ))} + {isLoading && + } + + +
+
+ ) +} + +export default ChatMessages \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/SideContent.tsx b/infrastructure/application/app/(root)/(routes)/chat/components/SideContent.tsx new file mode 100644 index 0000000..182e3b5 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/SideContent.tsx @@ -0,0 +1,17 @@ +import React from 'react' +import BotAvatar from './BotAvatar' + +const SideContent = () => { + return ( +
+

1 Message

+
+ +

Dynamite ChatBot

+
+ +
+ ) +} + +export default SideContent \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/UserAvatar.tsx b/infrastructure/application/app/(root)/(routes)/chat/components/UserAvatar.tsx new file mode 100644 index 0000000..83a431c --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/UserAvatar.tsx @@ -0,0 +1,25 @@ +"use client"; + +import { useUser } from "@clerk/nextjs"; +import Image from "next/image"; + +/** + * This component renders the current user avatar. + * It will load the user information (containing URL to the user image) from clerk, then it renders the image using the loaded information. + */ +export const UserAvatar = () => { + const { user } = useUser(); + // Check if user information is available (not null). + // If not, simply return. + if (!user) return + // User information is available for rendering. + return ( + user image + ); +}; \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/components/constants.ts b/infrastructure/application/app/(root)/(routes)/chat/components/constants.ts new file mode 100644 index 0000000..0b836e1 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/components/constants.ts @@ -0,0 +1,7 @@ +import * as z from "zod"; + +export const formSchema = z.object({ + prompt: z.string().min(1, { + message: "Prompt is required." + }), +}); \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/chat/page.tsx b/infrastructure/application/app/(root)/(routes)/chat/page.tsx new file mode 100644 index 0000000..668e453 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/chat/page.tsx @@ -0,0 +1,26 @@ +import { ChatContent } from "./components/ChatContent" +import prismadb from "@/lib/prismadb" +import SideContent from "./components/SideContent" + +/** + * This components renders the chat bot page. + */ +const MessengerPage = async () => { + const chatBot = await prismadb.chatBot.findFirst({ + include: { + messages: true + } + }) + return ( +
+
+ +
+
+ +
+
+ ) +} + +export default MessengerPage \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion-container.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion-container.tsx new file mode 100644 index 0000000..09b0e60 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion-container.tsx @@ -0,0 +1,51 @@ +/** + * Component representing an Accordion container. + * + * This component renders an Accordion component containing multiple AccordionItem components, + * each displaying different content based on user interaction. + */ + +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/shadcn-ui/accordion" +import { + DoughnutChart, + PortfolioItem, + Watchlist +} from './index'; + +/** + * Functional component representing an Accordion container. + * + * @returns JSX.Element representing the Accordion container. + */ +export function AccordionContainer() { + return ( + + {/* AccordionItem for Portfolio. */} + + Porfolio + + + + + {/* AccordionItem for Watchlist. */} + + Watchlist + + + + + {/* AccordionItem for Distribution. */} + + Distribution + + + + + + ) +} diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/doughnut-chart.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/doughnut-chart.tsx new file mode 100644 index 0000000..70cb985 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/doughnut-chart.tsx @@ -0,0 +1,84 @@ +"use client" + +import 'chart.js/auto'; +import { Company, Portfolio } from '@prisma/client'; +import axios from 'axios'; +import Image from 'next/image'; +import { useEffect } from 'react'; +import { useQuery } from '@tanstack/react-query' +import Skeleton, { SkeletonTheme } from 'react-loading-skeleton'; +import { ThreeDots } from 'react-loading-icons' +import { Doughnut } from 'react-chartjs-2'; +import { Portfolio_Company } from '@prisma/client'; +export function DoughnutChart() { + const { data: queryData, isLoading } = useQuery({ + queryKey: ['getPortfolio'], + queryFn: async () => { + const response = await axios.get('/api/portfolio') + return response.data; + }, + staleTime: 3600000 // 1 hour in ms + }) + if (isLoading || !queryData) { + return ( +
+ ...Loading +
+ ); + } + + const arrSymbols = queryData.companies.map((item: Company) => item.symbol); + const arrData: number[] = queryData.companies.map((item: Portfolio_Company) => { + return item.shares * item.company.price + }) + const total = arrData.length === 0 ? 0 : arrData.length === 1 ? 1 : arrData.reduce((a, b) => a + b) + const arrPercent = total === 0 ? [] : total === 1 ? [100] : arrData.map((item) => Math.max(item / total * 100, 2)) + + const bgColors = ['#035380', '#00803d', '#803335', '#17807b', '#144f80'] + //const arrValues = queryData.map((item)=>item.) + const options = {} + const data = + { + labels: arrSymbols, + datasets: [{ + label: '% of Total', + data: arrPercent, + borderWidth: 4, + hoverOffset: 15, + hoverBorderColor: '#6149cd', + }] + }; + + + return ( +
+
+

{arrSymbols.length}

+

Stocks

+
+
+ +
+
+ ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/portfolio-item.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/portfolio-item.tsx new file mode 100644 index 0000000..7792b24 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/portfolio-item.tsx @@ -0,0 +1,68 @@ +"use client" + +import { Company, Portfolio } from '@prisma/client'; +import axios from 'axios'; +import Image from 'next/image'; +import { useEffect } from 'react'; +import { useQuery } from '@tanstack/react-query' +import Skeleton, { SkeletonTheme } from 'react-loading-skeleton'; +import { ThreeDots } from 'react-loading-icons' +import { Portfolio_Company } from '@prisma/client'; + +export const PortfolioItem = () => { + + const { data: queryData, isLoading } = useQuery({ + queryKey: ['getPortfolio'], + queryFn: async () => { + const response = await axios.get('/api/portfolio') + return response.data; + }, + }) + + if (isLoading || !queryData) { + return ( +
+ ...Loading +
+ ); + } + //const data: Company[] = queryData.companies.map((item: Portfolio_Company & Company) => item.company); + const data2 = queryData.companies + console.log(data2) + return ( + <> + { + !data2 || data2.length === 0 ? +
+ No stocks added. +
: +
+ {data2.map((company: Portfolio_Company & Company) => ( + + ))} +
+ } + + ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/watchlist.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/watchlist.tsx new file mode 100644 index 0000000..0254853 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/accordion/watchlist.tsx @@ -0,0 +1,63 @@ +"use client" + +import { Company } from "@prisma/client"; +import { useQuery } from "@tanstack/react-query"; +import axios from "axios"; +import { Heart } from "lucide-react"; +import Image from "next/image"; +export const Watchlist = () => { + + const { data, isLoading } = useQuery({ + queryKey: ['getWatchlist'], + queryFn: async () => { + const response = await axios.get('/api/watchlist') + return response.data; + }, + staleTime: 3600000 // 1 hour in ms only runs once when the component mounts + }) + + if (isLoading) { + return ( +
+ ...Loading +
+ ); + } + return ( + <> + {!data || data.length === 0 ? +
No watchlist added
: + +
+ {data.map((company) => ( +
+
+ stock img +
+

{company.yahooStockV2Summary.price.shortName}

+

just now

+
+
+
+

${company.price}

+

+ +

+
+
+ ))} +
+ } + + ) + +} diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/account-container.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/account-container.tsx new file mode 100644 index 0000000..4c8cbce --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/account-container.tsx @@ -0,0 +1,56 @@ +"use client" + +import { dashboardHeaders } from '@/constants' +import { AccountCard } from './index' +import { QueryClient, useQueryClient } from '@tanstack/react-query' +import { useQuery } from '@tanstack/react-query' +import axios from 'axios' +import { Skeleton } from '@/components/shadcn-ui/skeleton' +import { useAccount } from '@/hooks/use-account' +import { useEffect } from 'react' + +const AccountContainer = () => { + const { accountVal, setAccountVal } = useAccount() + const { data, isLoading } = useQuery({ + queryKey: ['getAccount2'], + queryFn: async () => { + const response = await axios.get('/api/account') + return response.data; + }, + }) + useEffect(() => { + if (!isLoading) { + setAccountVal(data.accountBalance) + } + }, [data]) + if (isLoading) { + return
+ {dashboardHeaders.map((item, index) => + + )} +
+ } + + const total = data.accountBalance + data.portfolio.portfolioVal + const dataArr = [ + total, + data.accountBalance, + data.portfolio.portfolioVal + ] + + return ( +
+ {dashboardHeaders.map((item, index) => + + )} +
+ ) +} + +export default AccountContainer \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/account/account-card.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/account/account-card.tsx new file mode 100644 index 0000000..0ab61b6 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/account/account-card.tsx @@ -0,0 +1,61 @@ +"use client" + +import { AiOutlineRise } from "react-icons/ai"; +import { MdAccountBalance, MdCardTravel } from "react-icons/md"; +import { LuWallet } from "react-icons/lu"; + +import { LineChart } from "./line-chart"; + +interface AccountCardProps { + title: string, + value: number, + percentChange: string + index: number +} + +const icons = [ + { + icon: MdAccountBalance + }, + { + icon: LuWallet + }, + { + icon: MdCardTravel + } +] + +export const AccountCard = ( + { title, value, percentChange, index }: AccountCardProps +) => { + + const selectedIcon = icons[index] + const formattedVal = value.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ","); + + return ( +
+ + +

+ {title} +

+

+ ${formattedVal} +

+ +
+
+ +

+30.23%

+
+ +
+
+ ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/account/baseline-chart.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/account/baseline-chart.tsx new file mode 100644 index 0000000..e69de29 diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/account/line-chart.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/account/line-chart.tsx new file mode 100644 index 0000000..e5fc042 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/account/line-chart.tsx @@ -0,0 +1,58 @@ +"use client" + +import { DeepPartial, LastPriceAnimationMode, TimeChartOptions, createChart } from 'lightweight-charts'; +import { useRef, useEffect } from 'react'; + +export const LineChart = () => { + const containerRef = useRef(null) + const upTrend = true; + useEffect(() => { + if (!containerRef.current) { + return; + } + const chartOptions: DeepPartial = { + layout: { + textColor: 'black', + background: { + color: 'transparent' + } + }, + grid: { + vertLines: { + visible: false + }, + horzLines: { + visible: false + } + }, + rightPriceScale: { + visible: false + }, + timeScale: { + visible: false + } + }; + const chart = createChart(containerRef.current, chartOptions); + + const areaSeries = chart.addLineSeries({ color: upTrend ? '#14a34a' : '#ea2c2b', lastPriceAnimation: LastPriceAnimationMode.Continuous, priceLineVisible: false }); + + areaSeries.setData([ + { time: '2018-12-22', value: 22.51 }, + { time: '2018-12-23', value: 23.11 }, + { time: '2018-12-24', value: 25.02 }, + { time: '2018-12-25', value: 27.32 }, + { time: '2018-12-26', value: 24.17 }, + { time: '2018-12-27', value: 23.89 }, + { time: '2018-12-28', value: 28.46 }, + { time: '2018-12-29', value: 29.92 }, + { time: '2018-12-30', value: 35.68 }, + { time: '2018-12-31', value: 40.67 }, + ]); + + chart.timeScale().fitContent(); + }, []) + + return ( +
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank-container.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank-container.tsx new file mode 100644 index 0000000..774d297 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank-container.tsx @@ -0,0 +1,15 @@ +"use client" +import { DropdownContent } from "./bank/dropdown-content" +import { Wrapper } from "./bank/wrapper" +export function BankContainer() { + return ( +
+
+

My Cards

+ + +
+ +
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/add-card.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/add-card.tsx new file mode 100644 index 0000000..b714150 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/add-card.tsx @@ -0,0 +1,187 @@ +"use client" +import { + DialogContent +} from "@/components/shadcn-ui/dialog" +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import axios from "axios"; +import { z } from "zod" +import toast from "react-hot-toast" +import { Button } from "@/components/shadcn-ui/button" +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/shadcn-ui/card" +import { Input } from "@/components/shadcn-ui/input" +import { Label } from "@/components/shadcn-ui/label" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + SelectGroup, + SelectLabel +} from "@/components/shadcn-ui/select" +import { useForm } from "react-hook-form" +import { zodResolver } from "@hookform/resolvers/zod"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, + FormDescription +} from "@/components/shadcn-ui/form"; +import { useOrder } from "@/hooks/use-order"; +import { ColorSelector } from "./color-selector"; +import { useColor } from '@/hooks/use-color' +const FormSchema = z.object({ + name: z.string().min(2, { + message: "Name must be at least 2 characters.", + }), + cardNumber: z.string().length(19, { + message: "Card Number must have 19 chracters including space.", + }).regex(/^[0-9]{4}\s[0-9]{4}\s[0-9]{4}\s[0-9]{4}$/, { + message: "Must be in format 'xxxx xxxx xxxx xxxx'" + }), + value: z.coerce.number().gte(5000, { + message: "Value must be from $5,000 to $500,000", + }).lte(500000, { + message: "Value must be from $5,000 to $500,000", + }), + expiryDate: z.string().regex(/^([0-9][0-9])\/([0-1][0-9])$/, { + message: "Must be in the format of yy/mm." + }) +}) + +export const AddCard = () => { + const { isOpen, setIsOpen } = useOrder() + const { color } = useColor() + const form = useForm>({ + resolver: zodResolver(FormSchema), + defaultValues: { + name: "", + cardNumber: "0000 0000 0000 0000", + value: 250000, + expiryDate: "", + }, + }) + + function onSubmit(data: z.infer) { + addCard(data) + } + + const queryClient = useQueryClient(); + const { mutate: addCard, isPending } = useMutation({ + mutationFn: (data: z.infer) => { + return axios.post("/api/card/add", { data, color }) + }, + onError: (error) => { + toast.error("Failed adding new card. Duplicate card found") + }, + onSuccess: () => { + toast.success("Add Card Successfully") + form.reset() + setIsOpen(!isOpen) + queryClient.invalidateQueries({ + queryKey: ['getCard'], + exact: true, + refetchType: 'all' + }) + } + }) + return ( + +
+ + + Card Information + Add your card with in one-click. + + +
+ + ( + + Name + + + + + + + )} + /> + + ( + + Card Number + + + + + + + )} + /> + +
+ ( + + Value + +
+

$

+ +
+
+ + +
+ )} + /> + + ( + + Exp Date + + + + + + + )} + /> +
+ +
+ + +
+ + +
+
+
+
+ + ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/card-content.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/card-content.tsx new file mode 100644 index 0000000..a8ef138 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/card-content.tsx @@ -0,0 +1,40 @@ +import { cn } from "@/lib/utils" +import { Card } from "@prisma/client" +import { useQueryClient } from "@tanstack/react-query" + +interface CardContentProps { + className?: string, + order?: string, + cardData: Card +} + +export const CardContent = ({ className, order, cardData }: CardContentProps) => { + const firstData = cardData + + const visibleDigits = firstData.cardDigits.slice(0, 3); + const lastTwoDigits = firstData.cardDigits.slice(-2); + const hiddenDigits = "* **** **** **"; + const formattedCardNumber = visibleDigits + hiddenDigits + lastTwoDigits; + return ( +
+
+

+ {firstData.holderName} + VISA +

+ +
+

{formattedCardNumber}

+ +

+ ${firstData.value} + {firstData.expiration} +

+
+
+
+ ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/color-selector.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/color-selector.tsx new file mode 100644 index 0000000..dd8ed36 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/color-selector.tsx @@ -0,0 +1,44 @@ +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectLabel, + SelectTrigger, + SelectValue, +} from "@/components/shadcn-ui/select" +import { useColor } from "@/hooks/use-color"; +import { cn } from "@/lib/utils"; +import { useEffect, useState } from "react"; +import { HexColorPicker } from "react-colorful"; +import { useRef } from "react"; + +export const ColorSelector = () => { + const { color, setColor } = useColor() + const [isOpen, setIsOpen] = useState(false); + const presetColors = ["#5b13f4", "#5b4ff4", "#bb4ff4", "#bb4f99", "#00b59d", "#b8b198"]; + + const handleClick = (input: string) => { + setColor(input) + } + return ( + + ) +} diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/deposit.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/deposit.tsx new file mode 100644 index 0000000..dcc7440 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/deposit.tsx @@ -0,0 +1,241 @@ +"use client" +import { + DialogContent +} from "@/components/shadcn-ui/dialog" +import { QueryClient, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import axios from "axios"; +import { z } from "zod" +import toast from "react-hot-toast" +import { Button } from "@/components/shadcn-ui/button" +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/shadcn-ui/card" +import { Input } from "@/components/shadcn-ui/input" +import { Label } from "@/components/shadcn-ui/label" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + SelectGroup, + SelectLabel +} from "@/components/shadcn-ui/select" +import { useForm } from "react-hook-form" +import { zodResolver } from "@hookform/resolvers/zod"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, + FormDescription +} from "@/components/shadcn-ui/form"; +import { useOrder } from "@/hooks/use-order"; +import { Card as CardModel } from "@prisma/client"; +import { TbArrowBarToRight } from "react-icons/tb"; +import { useState } from "react"; +import { useAccount } from "@/hooks/use-account"; +import { useToast } from "@/components/shadcn-ui/use-toast"; +import { useColor } from "@/hooks/use-color"; + +function getFormattedDigits(data: string) { + const visibleDigits = data.slice(0, 3); + const lastTwoDigits = data.slice(-4); + const hiddenDigits = "* **** **** "; + const formattedCardNumber = visibleDigits + hiddenDigits + lastTwoDigits; + return formattedCardNumber; +} + + +export const Deposit = ({ cardLists }: { cardLists: CardModel[] }) => { + const { color } = useColor() + const { toast: shadcnToast } = useToast() + const [shareLabel, setShareLabel] = useState('') + const { accountVal } = useAccount() + const { isOpen, setIsOpen } = useOrder() + + const FormSchema = z.object({ + cardNum: z + .string({ + required_error: "Please select a card.", + }), + value: z.coerce.number().gte(500, { + message: "Value must be from $500 to $50,000", + }).lte(50000, { + message: "Value must be from $500 to $50,000", + }), + }) + + const form = useForm>({ + resolver: zodResolver(FormSchema), + }) + function onSubmit(data: z.infer) { + deposit(data) + } + + const queryClient = useQueryClient() + const { mutate: deposit, isPending } = useMutation({ + mutationFn: (data: z.infer) => { + return axios.patch("/api/card/deposit", data) + }, + onError: (error) => { + console.log(error) + shadcnToast({ + variant: "destructive", + title: "Insufficient card balance.", + }) + queryClient.invalidateQueries({ + queryKey: ['getTransaction'], + exact: true, + refetchType: 'all' + }) + }, + onSuccess: () => { + toast.success("Deposit Successfully") + setIsOpen(!isOpen) + queryClient.invalidateQueries({ + queryKey: ['getCard'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getAccount2'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getTransaction'], + exact: true, + refetchType: 'all' + }) + form.reset() + } + }) + const getCardVal = (data: string) => { + const foundData = cardLists.find((item) => item.cardDigits === data) + return foundData ? foundData.value : "" + } + const getCardExp = (data: string) => { + const foundData = cardLists.find((item) => item.cardDigits === data) + return foundData ? foundData.expiration : "" + } + if (!cardLists || cardLists.length === 0) { + toast.error("Please add card to deposit !") + return; + } + + return ( + +
+ + + Deposit Money + Deposit to your account with in one-click. + + + +
+ + ( + + + + You can manage your cards in the Card Section + + + {field.value && + <> +
+
+

Card Value:

+

${getCardVal(field.value)}

+
+
+

Expiration:

+

{getCardExp(field.value)}

+
+
+ +
+ ( + + Amount + +
+

$

+ { + field.onChange(e) + setShareLabel(`${e.target.value ? parseFloat(e.target.value) + accountVal : accountVal}`); + }} + /> +
+
+ + +
+ )} + /> + +
+
Account Value
+
${shareLabel}
+
+
+ + } +
+ + +
+
+ )} + /> + + +
+ +
+ +
+
+ ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/dropdown-content.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/dropdown-content.tsx new file mode 100644 index 0000000..d9cdcf0 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/dropdown-content.tsx @@ -0,0 +1,103 @@ +"use client" +import { + DropdownMenu, + DropdownMenuTrigger, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator +} from "@/components/shadcn-ui/dropdown-menu" +import { Deposit } from "./deposit" +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, + DialogFooter, + DialogOverlay +} from "@/components/shadcn-ui/dialog" +import { useState } from "react" +import { Button } from "@/components/shadcn-ui/button" +import { MoreHorizontal } from "lucide-react" +import { useOrder } from "@/hooks/use-order" +import { Withdraw } from "./withdraw" +import { AddCard } from "./add-card" +import { SwitchCard } from "./switch-card" +import { RemoveCard } from "./remove-card" +import { useQuery } from "@tanstack/react-query" +import axios from "axios" +import { useToast } from "@/components/shadcn-ui/use-toast" +import { ViewCards } from "./view-cards" +import { Skeleton } from "@/components/shadcn-ui/skeleton" +export const DropdownContent = () => { + const { order, setOrder, isOpen, setIsOpen } = useOrder() + const { toast } = useToast() + const { data, isLoading } = useQuery({ + queryKey: ['getCard'], + queryFn: async () => { + const result = await axios.get("/api/card") + return result.data; + } + }) + + if (isLoading) return ( + + ) + return ( +
+
+ {data.length} +
+ + + + + + + setOrder("1")}> + View all cards + + + + setOrder("2")}> + Add card + + { + setOrder("3") + !data || data.length === 0 && + toast({ + variant: "destructive", + title: "No card found.", + }) + }}> + Remove card + + + + setOrder("4")} > + Deposit + + setOrder("5")} > + Withdraw + + + + + { + order === '0' ? <> : + order === '1' ? : + order === '2' ? : + order === '3' ? : + order === '4' ? : + + } + +
+ + ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/remove-card.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/remove-card.tsx new file mode 100644 index 0000000..48ec159 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/remove-card.tsx @@ -0,0 +1,184 @@ +"use client" +import { + DialogContent +} from "@/components/shadcn-ui/dialog" +import { QueryClient, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import axios from "axios"; +import { z } from "zod" +import toast from "react-hot-toast" +import { Button } from "@/components/shadcn-ui/button" +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/shadcn-ui/card" +import { Input } from "@/components/shadcn-ui/input" +import { Label } from "@/components/shadcn-ui/label" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + SelectGroup, + SelectLabel +} from "@/components/shadcn-ui/select" +import { useForm } from "react-hook-form" +import { zodResolver } from "@hookform/resolvers/zod"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, + FormDescription +} from "@/components/shadcn-ui/form"; +import { useOrder } from "@/hooks/use-order"; +import { Card as CardModel } from "@prisma/client"; +import { useToast } from "@/components/shadcn-ui/use-toast"; + +const FormSchema = z.object({ + cardNum: z + .string({ + required_error: "Please select a card.", + }), +}) + +function getFormattedDigits(data: string) { + const visibleDigits = data.slice(0, 3); + const lastTwoDigits = data.slice(-4); + const hiddenDigits = "* **** **** "; + const formattedCardNumber = visibleDigits + hiddenDigits + lastTwoDigits; + return formattedCardNumber; +} + + +export const RemoveCard = ({ cardLists }: { cardLists: CardModel[] }) => { + const { isOpen, setIsOpen } = useOrder() + const { toast: shadcnToast } = useToast() + const form = useForm>({ + resolver: zodResolver(FormSchema), + }) + function onSubmit(data: z.infer) { + console.log(data) + addCard(data) + } + // const { data: cardLists, isLoading } = useQuery({ + // queryKey: ['getCard2'], + // queryFn: async () => { + // const result = await axios.get("/api/card") + // return result.data; + // } + // }) + // if (isLoading) return ( + //
...Loading
+ // ) + + + const queryClient = useQueryClient() + const { mutate: addCard, isPending } = useMutation({ + mutationFn: (data: z.infer) => { + return axios.post("/api/card/remove", data) + }, + onError: (error) => { + console.log(error) + }, + onSuccess: () => { + toast.success("Remove Card Successfully") + form.reset() + setIsOpen(!isOpen) + queryClient.invalidateQueries({ + queryKey: ['getCard'], + exact: true, + refetchType: 'all' + }) + } + }) + const getCardVal = (data: string) => { + const foundData = cardLists.find((item) => item.cardDigits === data) + return foundData ? foundData.value : "" + } + const getCardExp = (data: string) => { + const foundData = cardLists.find((item) => item.cardDigits === data) + return foundData ? foundData.expiration : "" + } + if (!cardLists || cardLists.length === 0) { + toast.error("No Cards Found!") + return; + } + + return ( + +
+ + + Remove Card + Remove your card with in one-click. + + + +
+ + ( + + + + You can manage your cards in the Card Section + + + {field.value && +
+
+

Card Value:

+

${getCardVal(field.value)}

+
+
+

Expiration:

+

{getCardExp(field.value)}

+
+
} +
+ + +
+
+ )} + /> + + +
+ +
+ +
+
+ + ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/switch-card.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/switch-card.tsx new file mode 100644 index 0000000..ced7b76 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/switch-card.tsx @@ -0,0 +1,33 @@ +"use client" +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, + DialogFooter, + DialogOverlay +} from "@/components/shadcn-ui/dialog" + +import { Button } from "@/components/shadcn-ui/button" + +export const SwitchCard = () => { + return ( + + + Are you absolutely sure? + + This action cannot be undone. Are you sure you want to permanently + delete this file from our servers? + + + + + + + + + ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/view-cards.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/view-cards.tsx new file mode 100644 index 0000000..0d193b5 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/view-cards.tsx @@ -0,0 +1,56 @@ +import { Label } from "@/components/shadcn-ui/label" +import toast from "react-hot-toast" +import { Button } from "@/components/shadcn-ui/button" +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, + DialogFooter, + DialogOverlay +} from "@/components/shadcn-ui/dialog" +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/shadcn-ui/card" +import { cn } from "@/lib/utils" +import { useState } from "react" +import { Card as CardModel } from "@prisma/client" +import { CardContent as CardContainer } from "./card-content" +import { Separator } from "@/components/shadcn-ui/separator" + +interface ViewCardsProps { + cardLists: CardModel[] +} + +export const ViewCards = ({ cardLists }: ViewCardsProps) => { + const [isOpen, setIsOpen] = useState(false) + + return ( + +
+ + + Available Cards + manage your cards under my cards section. + + + +
+ {cardLists.map((item) => + + )} +
+
+
+
+
+ + ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/withdraw.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/withdraw.tsx new file mode 100644 index 0000000..8a3e03e --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/withdraw.tsx @@ -0,0 +1,246 @@ +"use client" +import { + DialogContent +} from "@/components/shadcn-ui/dialog" +import { QueryClient, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import axios from "axios"; +import { z } from "zod" +import toast from "react-hot-toast" +import { Button } from "@/components/shadcn-ui/button" +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle, +} from "@/components/shadcn-ui/card" +import { Input } from "@/components/shadcn-ui/input" +import { Label } from "@/components/shadcn-ui/label" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, + SelectGroup, + SelectLabel +} from "@/components/shadcn-ui/select" +import { useForm } from "react-hook-form" +import { zodResolver } from "@hookform/resolvers/zod"; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, + FormDescription +} from "@/components/shadcn-ui/form"; +import { useOrder } from "@/hooks/use-order"; +import { Card as CardModel } from "@prisma/client"; +import { TbArrowBarToRight } from "react-icons/tb"; +import { useState } from "react"; +import { useAccount } from "@/hooks/use-account"; +import { useToast } from "@/components/shadcn-ui/use-toast"; +const FormSchema = z.object({ + cardNum: z + .string({ + required_error: "Please select a card.", + }), + value: z.coerce.number().gte(500, { + message: "Value must be from $500 to $50,000", + }).lte(50000, { + message: "Value must be from $500 to $50,000", + }), +}) + +function getFormattedDigits(data: string) { + const visibleDigits = data.slice(0, 3); + const lastTwoDigits = data.slice(-4); + const hiddenDigits = "* **** **** "; + const formattedCardNumber = visibleDigits + hiddenDigits + lastTwoDigits; + return formattedCardNumber; +} + + +export const Withdraw = ({ cardLists }: { cardLists: CardModel[] }) => { + const { toast: shadcnToast } = useToast() + const [shareLabel, setShareLabel] = useState('') + const { accountVal } = useAccount() + const { isOpen, setIsOpen } = useOrder() + + const form = useForm>({ + resolver: zodResolver(FormSchema), + }) + function onSubmit(data: z.infer) { + withdraw(data) + } + // const { data: cardLists, isLoading } = useQuery({ + // queryKey: ['getCard2'], + // queryFn: async () => { + // const result = await axios.get("/api/card") + // return result.data; + // } + // }) + // if (isLoading) return ( + //
...Loading
+ // ) + + + const queryClient = useQueryClient() + const { mutate: withdraw, isPending } = useMutation({ + mutationFn: (data: z.infer) => { + return axios.patch("/api/card/withdraw", data) + }, + onError: (error) => { + shadcnToast({ + variant: "destructive", + title: "Insufficient account balance.", + }) + queryClient.invalidateQueries({ + queryKey: ['getTransaction'], + exact: true, + refetchType: 'all' + }) + console.log(error) + }, + onSuccess: () => { + toast.success("Withdraw Successfully") + setIsOpen(!isOpen) + queryClient.invalidateQueries({ + queryKey: ['getCard'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getAccount2'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getTransaction'], + exact: true, + refetchType: 'all' + }) + form.reset() + } + }) + const getCardVal = (data: string) => { + const foundData = cardLists.find((item) => item.cardDigits === data) + return foundData ? foundData.value : "" + } + const getCardExp = (data: string) => { + const foundData = cardLists.find((item) => item.cardDigits === data) + return foundData ? foundData.expiration : "" + } + if (!cardLists || cardLists.length === 0) { + toast.error("Please add card to withdraw!") + return; + } + return ( + +
+ + + Withdraw Money + Withdraw money with in one-click. + + + +
+ + ( + + + + You can manage your cards in the Card Section + + + {field.value && + <> +
+
+

Card Value:

+

${getCardVal(field.value)}

+
+
+

Expiration:

+

{getCardExp(field.value)}

+
+
+ +
+
+
Account Value
+
${shareLabel}
+
+ + ( + + Amount + +
+

$

+ { + field.onChange(e) + setShareLabel(`${e.target.value ? accountVal - parseFloat(e.target.value) : accountVal}`); + }} + /> +
+
+ + +
+ )} + /> +
+ + } +
+ + +
+
+ )} + /> + + +
+
+
+
+ ) +} + diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/wrapper.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/wrapper.tsx new file mode 100644 index 0000000..7d5917d --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/bank/wrapper.tsx @@ -0,0 +1,47 @@ +import { useQuery } from "@tanstack/react-query" +import axios from "axios" +import { CardContent } from "./card-content" +import { Card } from "@prisma/client" +import 'react-loading-skeleton/dist/skeleton.css' +import { Skeleton } from "@/components/shadcn-ui/skeleton" + +const contents = [ + { + className: "p-2", + order: "1", + }, + { + className: "px-2 absolute -z-[2] w-full top-2 -right-2 opacity-80", + order: "2" + } +] +export const Wrapper = () => { + + const { data, isLoading } = useQuery({ + queryKey: ['getCard'], + queryFn: async () => { + const result = await axios.get("/api/card") + return result.data; + } + }) + if (isLoading) return ( + + ) + return ( + <> + {!data || data.length === 0 ? +
+ No cards added. +
: +
+ {data.slice(0, 2).map((item, index) => ( + + ))} + +
+ } + + ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/index.ts b/infrastructure/application/app/(root)/(routes)/dashboard/components/index.ts new file mode 100644 index 0000000..a9c25a5 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/index.ts @@ -0,0 +1,36 @@ +// Import container components +import { BankContainer } from "./bank-container"; +import AccountContainer from "./account-container"; +import { AccordionContainer } from "./accordion-container"; +import TableContainer from "./table-container"; + +// Import table components +import { ColumnHeader } from "./table/column-header"; +import { columns } from "./table/columns" +import { DataTable } from "./table/data-table"; +import { DataTablePagination } from "./table/data-table-pagination"; + +// Import account components +import { AccountCard } from "./account/account-card"; + +// Import accordion components +import { Watchlist } from "./accordion/watchlist"; +import { DoughnutChart } from "./accordion/doughnut-chart"; +import { PortfolioItem } from "./accordion/portfolio-item"; + +export { + BankContainer, + AccordionContainer, + AccountContainer, + TableContainer, + + ColumnHeader, + columns, + DataTable, + DataTablePagination, + AccountCard, + + Watchlist, + DoughnutChart, + PortfolioItem +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/table-container.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/table-container.tsx new file mode 100644 index 0000000..a527d35 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/table-container.tsx @@ -0,0 +1,38 @@ +"use client" +import { columns } from "./table/columns" +import { DataTable } from "./index" +import { useQuery } from "@tanstack/react-query"; +import axios from "axios"; +import { Skeleton } from "@/components/shadcn-ui/skeleton"; +import { Transaction } from "@prisma/client"; +import { Greetings } from "./table/greetings"; + + +const TableContainer = () => { + const { data: transactionData, isLoading } = useQuery({ + queryKey: ['getTransaction'], + queryFn: async () => { + const response = await axios.get("/api/transaction") + return response.data; + } + }) + if (isLoading || !transactionData) return ( +
+ +
+ ); + return ( +
+ + +
+ + ) +} + +export default TableContainer \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/table/column-header.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/column-header.tsx new file mode 100644 index 0000000..078ce84 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/column-header.tsx @@ -0,0 +1,72 @@ +import { + ArrowDownIcon, + ArrowUpIcon, + CaretSortIcon, + EyeNoneIcon, +} from "@radix-ui/react-icons" + +import { Column } from "@tanstack/react-table" + +import { cn } from "@/lib/utils" +import { Button } from "@/components/shadcn-ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/shadcn-ui/dropdown-menu" + +interface DataTableColumnHeaderProps + extends React.HTMLAttributes { + column: Column + title: string +} + +export function ColumnHeader({ + column, + title, + className, +}: DataTableColumnHeaderProps) { + if (!column.getCanSort()) { + return
{title}
+ } + + return ( +
+ + + + + + column.toggleSorting(false)}> + + Asc + + column.toggleSorting(true)}> + + Desc + + + column.toggleVisibility(false)}> + + Hide + + + +
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/table/columns.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/columns.tsx new file mode 100644 index 0000000..5715ed3 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/columns.tsx @@ -0,0 +1,92 @@ +"use client" +import { ArrowUpDown, MoreHorizontal } from "lucide-react" +import { ColumnDef } from "@tanstack/react-table" +import { ColumnHeader } from "./column-header" + +import { Button } from "@/components/shadcn-ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/shadcn-ui/dropdown-menu" +import { Checkbox } from "@/components/shadcn-ui/checkbox" +import { Transaction } from "@prisma/client" +import Image from "next/image" + + +export const columns: ColumnDef[] = [ + { + id: "select", + header: ({ table }) => ( + table.toggleAllPageRowsSelected(!!value)} + aria-label="Select all" + /> + ), + cell: ({ row }) => ( + row.toggleSelected(!!value)} + aria-label="Select row" + /> + ), + enableSorting: false, + enableHiding: false, + }, + { + accessorKey: "status", + header: "Status", + cell: ({ row }) => { + return ( +
+ {row.getValue('status')} +
+ ) + } + }, + { + // key: the data model collumn name + accessorKey: "id", + header: "Transaction ID" + }, + { + accessorKey: "type", + header: "Type", + cell: ({ row }) => { + const type: string = row.getValue("type"); + const firstType = type.split(" ")[0] + return ( +
+ {type} +
) + } + }, + { + accessorKey: "amount", + header: "Amount", + cell: ({ row }) => { + const type: string = row.getValue("type"); + const firstType = type.split(" ")[0] + const operator = (type === 'withdraw' || firstType === 'buy') ? '-' : '+'; + return
{operator}${row.getValue('amount')}
+ } + }, + { + accessorKey: "createdAt", + header: "Date", + cell: ({ row }) => { + const value = new Date(row.getValue('createdAt')) + return
{value.toLocaleString()}
+ } + }, + +] diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table-pagination.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table-pagination.tsx new file mode 100644 index 0000000..cb4be66 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table-pagination.tsx @@ -0,0 +1,98 @@ +import { + ChevronLeftIcon, + ChevronRightIcon, + DoubleArrowLeftIcon, + DoubleArrowRightIcon, +} from "@radix-ui/react-icons" +import { Table } from "@tanstack/react-table" + +import { Button } from "@/components/shadcn-ui/button" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/shadcn-ui/select" + +interface DataTablePaginationProps { + table: Table +} + +export function DataTablePagination({ + table, +}: DataTablePaginationProps) { + return ( +
+
+ {table.getFilteredSelectedRowModel().rows.length} of{" "} + {table.getFilteredRowModel().rows.length} row(s) selected. +
+
+
+

Rows per page

+ +
+
+ Page {table.getState().pagination.pageIndex + 1} of{" "} + {table.getPageCount()} +
+
+ + + + +
+
+
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table.tsx new file mode 100644 index 0000000..d34e042 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/data-table.tsx @@ -0,0 +1,172 @@ +"use client" + +import * as React from "react" +import { + ColumnDef, + flexRender, + getCoreRowModel, + useReactTable, + getPaginationRowModel, + SortingState, + getSortedRowModel, + ColumnFiltersState, + getFilteredRowModel, + VisibilityState, +} from "@tanstack/react-table" + +import { + DropdownMenu, + DropdownMenuCheckboxItem, + DropdownMenuContent, + DropdownMenuTrigger, +} from "@/components/shadcn-ui/dropdown-menu" +import { Checkbox } from "@/components/shadcn-ui/checkbox" +import { Input } from "@/components/shadcn-ui/input" +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/shadcn-ui/table" +import { Button } from "@/components/shadcn-ui/button" +import { DataTablePagination } from "./data-table-pagination" +import { SlidersHorizontal } from "lucide-react" + + +interface DataTableProps { + columns: ColumnDef[] + data: TData[] +} + +export function DataTable({ + columns, + data, +}: DataTableProps) { + + const [sorting, setSorting] = React.useState([]) + const [columnFilters, setColumnFilters] = React.useState( + [] + ) + const [columnVisibility, setColumnVisibility] = + React.useState({}) + const [rowSelection, setRowSelection] = React.useState({}) + + const table = useReactTable({ + data, + columns, + getCoreRowModel: getCoreRowModel(), + getPaginationRowModel: getPaginationRowModel(), + onSortingChange: setSorting, + getSortedRowModel: getSortedRowModel(), + onColumnFiltersChange: setColumnFilters, + getFilteredRowModel: getFilteredRowModel(), + onColumnVisibilityChange: setColumnVisibility, + onRowSelectionChange: setRowSelection, + state: { + sorting, + columnFilters, + columnVisibility, + rowSelection, + }, + }) + + + return ( +
+
+
+ + table.getColumn("id")?.setFilterValue(event.target.value) + } + className="max-w-[300px]" + /> + + + + + + {table + .getAllColumns() + .filter( + (column) => column.getCanHide() + ) + .map((column) => { + return ( + + column.toggleVisibility(!!value) + } + > + {column.id} + + ) + })} + + +
+
+ + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => { + return ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext() + )} + + ) + })} + + ))} + + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + )) + ) : ( + + + No transaction found. + + + ) + } + +
+
+
+ +
+
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/components/table/greetings.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/greetings.tsx new file mode 100644 index 0000000..66c3110 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/components/table/greetings.tsx @@ -0,0 +1,18 @@ +import { useUser } from "@clerk/nextjs" +import { redirect } from "next/navigation" + +export const Greetings = () => { + const user = useUser() + if (!user) { + redirect('/') + } + const firstName = user.user?.firstName + const lastName = user.user?.lastName + console.log(user) + return ( +
+

Welcome Back, {firstName} {lastName}

+

Here is the list of your recent transactions

+
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/dashboard/page.tsx b/infrastructure/application/app/(root)/(routes)/dashboard/page.tsx new file mode 100644 index 0000000..aded2ea --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/dashboard/page.tsx @@ -0,0 +1,58 @@ +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; +import { + AccountContainer, + BankContainer, + TableContainer, + AccordionContainer +} from "./components/index" +import { redirect } from "next/navigation"; +import { Company, Watchlist, Watchlist_Company } from "@prisma/client"; + +const DashboardPage = async () => { + const { userId } = auth() + if (!userId) { + redirect("/") + } + + // create new user account, skip if already exists + await prismadb.account.upsert({ + where: { + id: userId + }, + update: {}, + create: { + id: userId, + accountBalance: 0, + accountValue: 0, + portfolio: { + create: { + portfolioVal: 0 + } + }, + watchlist: { + create: { + name: "default" + } + } + } + }) + + return ( +
+
+ + +
+ +
+
+ + +
+
+
+ ) +} + +export default DashboardPage \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/company-profile/back-btn.tsx b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/back-btn.tsx new file mode 100644 index 0000000..da60f3e --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/back-btn.tsx @@ -0,0 +1,23 @@ +import { useAnimation } from "@/hooks/use-animation"; +import { MoveLeft } from "lucide-react"; +import { cn } from "@/lib/utils"; + +const BackBtn = () => { + const { animatedId, setAnimatedId } = useAnimation() + + return ( + + ) +} + +export default BackBtn \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-details.tsx b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-details.tsx new file mode 100644 index 0000000..fda3cf7 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-details.tsx @@ -0,0 +1,222 @@ +import { Company } from "@prisma/client" + +interface CompanyDetailsProps { + foundCompany: Company +} + +type financialData = { + ebitda: { fmt: string }, + totalCash: { fmt: string }, + totalDebt: { fmt: string }, + quickRatio: { fmt: string }, + currentPrice: { fmt: string }, + currentRatio: { fmt: string }, + debtToEquity: { fmt: string }, + freeCashflow: { fmt: string }, + grossMargins: { fmt: string }, + totalRevenue: { fmt: string }, + ebitdaMargins: { fmt: string }, + profitMargins: { fmt: string }, + revenueGrowth: { fmt: string }, + earningsGrowth: { fmt: string }, + returnOnAssets: { fmt: string }, + returnOnEquity: { fmt: string }, + targetLowPrice: { fmt: string }, + revenuePerShare: { fmt: string }, + targetHighPrice: { fmt: string }, + targetMeanPrice: { fmt: string }, + operatingMargins: { fmt: string }, + operatingCashflow: { fmt: string }, + targetMedianPrice: { fmt: string }, + totalCashPerShare: { fmt: string }, + recommendationMean: { fmt: string }, + numberOfAnalystOpinions: { fmt: string }, +} + +export type insiderHolder = { + holders: [ + { + name: string, + relation: string, + positionDirect: { + fmt: string, + } + } + ] +} +const CompanyDetails = ({ foundCompany }: CompanyDetailsProps) => { + + const targetData = foundCompany.yahooStockV2Summary + const insiderHolders: insiderHolder = targetData.insiderHolders + const financialData: financialData = targetData.financialData + + const overview = targetData.price; + const summary = targetData.summaryProfile + const unixTimestamp = targetData.price.regularMarketTime + const regularTime = new Date(unixTimestamp * 1000).toUTCString() + const headers = [ + "Sector", + "Industry", + "Fulltime Employees", + "Website", + "Address", + + "Summary", + + "Ebitda", + "TotalCash", + "TotalDebt", + "QuickRatio", + "CurrentPrice", + "CurrentRatio", + "DebtToEquity", + "FreeCashflow", + "GrossMargin", + "TotalRevenue", + + "EbitdaMargins", + "ProfitMargins", + "RevenueGrowth", + "EarningsGrowth", + "ReturnOnAssets", + "ReturnOnEquity", + "TargetLowPrice", + + "RevenuePerShare", + "TargetHighPric", + "TargetMeanPrice", + "OperatingMargins", + "OperatingCashflo", + "TargetMedianPrice", + "TotalCashPerShare", + "RecommendationMean", + "NumberOfAnalystOpinions", + + "Name", + "Role", + "Shares", + + "Company", + "Zip", + "City", + "Phone", + "State", + "Country", + "Industry", + "Sector", + "Description", + ] + + const address = summary.address1 + ", " + summary.city + ", " + summary.state + ", " + summary.zip; + const data = [ + summary.sector, + summary.industry, + summary.fullTimeEmployees, + summary.website, + address, + + summary.longBusinessSummary, + + financialData.ebitda.fmt, + financialData.totalCash.fmt, + financialData.totalDebt.fmt, + financialData.quickRatio.fmt, + financialData.currentPrice.fmt, + financialData.currentRatio.fmt, + financialData.debtToEquity.fmt, + financialData.freeCashflow.fmt, + financialData.grossMargins.fmt, + financialData.totalRevenue.fmt, + + financialData.ebitdaMargins.fmt, + financialData.profitMargins.fmt, + financialData.revenueGrowth.fmt, + financialData.earningsGrowth.fmt, + financialData.returnOnAssets.fmt, + financialData.returnOnEquity.fmt, + financialData.targetLowPrice.fmt, + financialData.revenuePerShare.fmt, + financialData.targetHighPrice.fmt, + financialData.targetMeanPrice.fmt, + financialData.operatingMargins.fmt, + financialData.operatingCashflow.fmt, + financialData.targetMedianPrice.fmt, + financialData.totalCashPerShare.fmt, + financialData.recommendationMean.fmt, + financialData.numberOfAnalystOpinions.fmt, + + insiderHolders.holders, + + overview.longName, + summary.zip, + summary.city, + summary.phone, + summary.state, + summary.country, + ] + + return ( + <> +
+

Updated at {regularTime}

+
Company Overview
+
+ {headers.slice(0, 5).map((item, index) => ( +
+

{item}

+

{data[index]}

+
+ ))} +
{data[5]}
+
+
+
+
Financial Data
+
+
+
Valuation
+ {headers.slice(6, 15).map((item, index) => ( +
+

{item}

+

{data[index + 6]}

+
+ ))} +
+ +
+
Revenue
+ {headers.slice(15, 25).map((item, index) => ( +
+

{item}

+

{data[index + 15]}

+
+ ))} +
+
+
+ +
Top Stakeholders
+
+ {insiderHolders.holders.map((item) => ( +
+

{item.name}

+
+

role

+

{item.relation}

+
+
+

Shares

+ {item.positionDirect?.fmt ? +

{item.positionDirect?.fmt}

: +

No public information available

+ } + +
+
+ ))} +
+ + ) +} + +export default CompanyDetails \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-profile.tsx b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-profile.tsx new file mode 100644 index 0000000..f316892 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/company-profile.tsx @@ -0,0 +1,125 @@ +import { Company } from "@prisma/client" +import Image from "next/image" +import Link from "next/link" +import BackBtn from "./back-btn" +import CompanyDetails, { insiderHolder } from "./company-details" +import EmployeeCard from "./employee-card" +import ExecutiveCard from "./executive-card" +export interface CompanyProfile { + company: Company +} + + + +export const CompanyProfile = ({ company }: CompanyProfile) => { + const insiderHolders = company.yahooStockV2Summary.insiderHolders + const arrHolders = insiderHolders.holders + + const fundOwnership = company.yahooStockV2Summary.fundOwnership.ownershipList[0] + + const numOfEmployees = [ + { + amount: company.yahooStockV2Summary.summaryProfile.fullTimeEmployees + '+', + name: "Employees" + }, + { + amount: '$' + company.yahooStockV2Summary.financialData.totalRevenue.fmt, + name: "Revenue" + }, + { + amount: "90+", + name: "Staff" + } + + ] + const executives = [ + { + name: arrHolders[0]?.name || "N/A", + role: arrHolders[0]?.relation || "", + shares: arrHolders[0]?.positionDirect?.fmt || "N/A", + avatarImg: "/avatars/ava1.png", + }, + { + name: arrHolders[1]?.name || "N/A", + role: arrHolders[1]?.relation || "", + shares: arrHolders[1]?.positionDirect?.fmt || "N/A", + avatarImg: "/avatars/ava2.png", + }, + { + name: arrHolders[2]?.name || "N/A", + role: arrHolders[2]?.relation || "", + shares: arrHolders[2]?.positionDirect?.fmt || "N/A", + avatarImg: "/avatars/ava3.png", + }, + ] + return ( + +
+
+
+ company logo +
+

{company.yahooStockV2Summary.price.shortName}

+

{company.symbol}

+
+
+ +
+
+
+ +
+
+
+
+

Top Holding

+

{fundOwnership.position.fmt}

+
+
+

${fundOwnership.value.fmt}

+

{fundOwnership.pctHeld.fmt}

+
+
+ CEO image +
+

{fundOwnership.organization}

+
+ +
+ +
+
+ {executives.map((item, index) => ( + + ))} +
+
+ {numOfEmployees.map((item, index) => ( + + ))} +
+
+
+
+
+ +
+
+
+ + ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/company-profile/employee-card.tsx b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/employee-card.tsx new file mode 100644 index 0000000..a3ab8a4 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/employee-card.tsx @@ -0,0 +1,20 @@ +import React from 'react' + +interface EmployeeCardProps { + data: { + amount: string, + name: string + } +} + +const EmployeeCard = ({ data }: EmployeeCardProps) => { + const { amount, name } = data; + return ( +
+

{name}

+

{amount}

+
+ ) +} + +export default EmployeeCard \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/company-profile/executive-card.tsx b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/executive-card.tsx new file mode 100644 index 0000000..8a4fa95 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/company-profile/executive-card.tsx @@ -0,0 +1,30 @@ +import Image from "next/image"; + +interface ExecutiveProps { + employee: { + name: string, + role: string, + avatarImg: string, + shares: string + } +} + +const ExecutiveCard = ({ employee }: ExecutiveProps) => { + const { name, role, avatarImg, shares } = employee; + return ( +
+ avatar +

{name}

+

{role}

+

{shares}

+
+ ) +} + +export default ExecutiveCard \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/market-container.tsx b/infrastructure/application/app/(root)/(routes)/market/components/market-container.tsx new file mode 100644 index 0000000..88898e3 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/market-container.tsx @@ -0,0 +1,85 @@ +"use client" + +import { useAnimate } from "framer-motion" +import { useEffect } from "react" + +import { useAnimation } from "@/hooks/use-animation" +import { useTicker } from "@/hooks/use-ticker" +import { Company } from "@prisma/client" +import { CompanyProfile } from "./company-profile/company-profile" +import { FeaturedProduct } from "./products/featured-product" +import TableContainer from "./table/table-container" +export interface MarketContainerProps { + companies: Company[] +} + +const MarketContainer = ({ companies }: MarketContainerProps) => { + const { ticker } = useTicker() + const foundCompany = companies.find((item: Company) => item.symbol === ticker) + const updatedAt = new Date(foundCompany.yahooStockV2Summary.price.regularMarketTime).toLocaleTimeString() + const { animatedId, setAnimatedId, firstLoop, setFirstLoop } = useAnimation() + const [frontElement, animate] = useAnimate(); + const [behindElement, animate1] = useAnimate(); + // Animation + useEffect(() => { + if (firstLoop) { + setFirstLoop(false); + return; + } + if (animatedId === 1) { + console.log(animatedId) + const windowWidth = window.innerWidth; + const xValue = -65 * (windowWidth / 100); + animate(frontElement.current, { opacity: 0, scale: .8, x: xValue }, { duration: .2 }); + animate1(behindElement.current, { scale: [.8, 1], x: [-xValue, 0] }, { duration: .2 }); + } + if (animatedId === 2) { + const windowWidth = window.innerWidth; + const xValue = 65 * (windowWidth / 100); + animate(frontElement.current, { opacity: 1, scale: 1, x: 0 }, { duration: .2 }); + animate1(behindElement.current, { scale: [1, .8], x: xValue }, { duration: .2 }); + } + }, [animatedId, firstLoop]) + // useEffect(() => { + // setAnimatedId(2) + // }, [pathname]) + return ( +
+ {/*Left Container*/} +
+ +
+ + {/*Right Container*/} +
+
+
+

Companies

+
+ Updated {updatedAt} +
+
+
+ +
+
+ +
+ +
+
+
+ ) +} + +export default MarketContainer \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/products/bar-chart.tsx b/infrastructure/application/app/(root)/(routes)/market/components/products/bar-chart.tsx new file mode 100644 index 0000000..6fac7d7 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/products/bar-chart.tsx @@ -0,0 +1,55 @@ +"use client" + +import 'chart.js/auto'; +import { Company } from "@prisma/client" +import { Bar } from "react-chartjs-2" +import Utils from "react-chartjs-2" +interface BarChartProps { + company: Company +} +type chartData = { + date: string, + revenue: { + raw: Number + }, + earnings: { + raw: Number + } +} +export const BarChart = ({ company }: BarChartProps) => { + const DATA_COUNT = 7; + //const NUMBER_CFG = { count: DATA_COUNT, min: -100, max: 100 }; + const values: chartData[] = company.yahooStockV2Summary.earnings.financialsChart.quarterly + const labels = values.map((value) => { + const firstTwo = value.date.slice(0, 2) + const last = value.date.slice(-4) + return firstTwo + ' ' + last + }) + const revenueData = values.map((value) => value.revenue.raw) + const earningsData = values.map((value) => value.earnings.raw) + return ( +
+ +
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/products/featured-product.tsx b/infrastructure/application/app/(root)/(routes)/market/components/products/featured-product.tsx new file mode 100644 index 0000000..d7b7c62 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/products/featured-product.tsx @@ -0,0 +1,85 @@ +"use client" + +import { Company } from "@prisma/client"; +import { useAnimate } from "framer-motion"; +import Image from "next/image"; +import { useEffect, useState } from "react"; +import Wrapper from "../transaction/wrapper"; +import Heart from "./heart"; +import { ProductDialog } from "./product-dialog"; +import { ProgressBar } from "./progress-bar"; +import { useTicker } from "@/hooks/use-ticker"; +interface FeaturedProductProps { + company: Company, +} + +export const FeaturedProduct: React.FC = ({ company }) => { + const { ticker } = useTicker() + const [scope, animate] = useAnimate(); + const [scope2, animate2] = useAnimate(); + const [isSwapped, setIsSwapped] = useState(true) + + const companyName = company.yahooStockV2Summary.price.shortName + // animate product img + useEffect(() => { + if (isSwapped) { + animate(scope.current, { x: 150, scale: 0, opacity: 0 }, { duration: 0 }) + const timeout = setTimeout(() => { + animate2(scope2.current, { x: 0, scale: 1, opacity: 1 }, { duration: .4 }) + }, 50) + setIsSwapped(!isSwapped) + return () => clearTimeout(timeout); + } + + animate2(scope2.current, { x: 150, scale: 0, opacity: 0 }, { duration: 0 }) + const timeout2 = setTimeout(() => { + animate(scope.current, { x: 0, scale: 1, opacity: 1 }, { duration: .4 }) + }, 50) + setIsSwapped(!isSwapped) + return () => clearTimeout(timeout2); + }, [ticker]) + return ( +
+
+
+

+ {companyName} +

+ + + +
+
+

+ ${company.price} /ea +

+ +
+
+ +
+ Product Image + Product Image +
+
+ +
+
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/products/heart.tsx b/infrastructure/application/app/(root)/(routes)/market/components/products/heart.tsx new file mode 100644 index 0000000..01f9753 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/products/heart.tsx @@ -0,0 +1,114 @@ +"use client" + +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import axios from "axios"; +import { Check } from "lucide-react"; +import { useEffect, useState } from "react"; +import { AiOutlineHeart } from "react-icons/ai"; +import { IoHeart } from "react-icons/io5"; +import { useToast } from "@/components/shadcn-ui/use-toast" +import { useTicker } from "@/hooks/use-ticker"; +import { useQuery } from "@tanstack/react-query"; +import { Company } from "@prisma/client"; +import { Skeleton } from "@/components/shadcn-ui/skeleton"; +import { ThreeDots } from "react-loading-icons"; + + +const Heart = () => { + const { toast } = useToast() + const { ticker, setIsLiking } = useTicker() + + const { data, isLoading } = useQuery({ + queryKey: ['getWatchlist'], + queryFn: async () => { + const response = await axios.get('/api/watchlist') + return response.data; + }, + //staleTime: 3600000 // 1 hour in ms only runs once when the component mounts + }) + const likedSymbols = data?.map((item) => item.symbol) // ['AAPL','AMD'] + const isContained = likedSymbols?.includes(ticker) + + //const [isLiked, setIsLiked] = useState(false) + const [isDisabled, setIsDisabled] = useState(false) + const [isLiked, setIsLiked] = useState(isContained) + + useEffect(() => { + setIsLiked(isContained) + }, [ticker]) + + function handleClicked() { + setIsLiking(true) + setIsDisabled(true) + setIsLiked(!isContained) + updateWatchlist() + + } + const queryClient = useQueryClient() + const { mutate: updateWatchlist, isPending } = useMutation({ + mutationFn: () => { + return axios.patch("/api/watchlist", { + isLiked: !isContained, + ticker: ticker + }) + }, + onError: (error) => { + console.log(error) + }, + onSuccess: () => { + setIsLiking(false) + const description = !isContained ? "Added to watchlist" : "Removed from watchlist" + queryClient.invalidateQueries({ + queryKey: ['getWatchlist'], + exact: true, + refetchType: 'all' + }) + toast({ + className: "shadow-lg gap-2", + duration: 1500, + description, + action: , + }) + + } + + }) + + if (isPending) return ( +
+ +
+ ) + + if (isLoading) return ( + + ) + + return ( + + ) +} + +export default Heart \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/products/product-dialog.tsx b/infrastructure/application/app/(root)/(routes)/market/components/products/product-dialog.tsx new file mode 100644 index 0000000..15dfd7d --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/products/product-dialog.tsx @@ -0,0 +1,48 @@ +import { ExternalLink } from "lucide-react" +import { Button } from "@/components/shadcn-ui/button" +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/shadcn-ui/dialog" +import Image from "next/image" +import { Separator } from "@/components/shadcn-ui/separator" +import { BarChart } from "./bar-chart" +import { Company } from "@prisma/client" + +interface ProductDialogProps { + company: Company +} + +export const ProductDialog = ({ company }: ProductDialogProps) => { + + return ( + + + + + + + + {company.yahooStockV2Summary.price.shortName} Revenue/Earnings + + Earning charts helps you understand the performance of the company. + + + + + + ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/products/progress-bar.tsx b/infrastructure/application/app/(root)/(routes)/market/components/products/progress-bar.tsx new file mode 100644 index 0000000..25a8ad2 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/products/progress-bar.tsx @@ -0,0 +1,73 @@ +"use client" +import { Progress } from "@/components/shadcn-ui/progress" +import { Company } from "@prisma/client" +import { useEffect, useState } from "react" +import BigNumber from 'bignumber.js' + +interface ProgressBarProps { + company: Company +} + +const getPortion = (portion: number) => { + // raging from 0 to 1000 + if (portion > 0 && portion < 10) return 10; + if (portion > 10 && portion < 50) return 20; + if (portion > 50 && portion < 150) return 50; + if (portion >= 150 && portion < 300) return 56; + if (portion > 300 && portion < 600) return 70; + if (portion > 600 && portion < 900) return 85; + else return 90 +} + +const getPortion2 = (portion: number) => { + // raging from 0 to 1000 + if (portion > 0 && portion < 10) return 10; + if (portion > 10 && portion < 50) return 20; + if (portion > 50 && portion < 150) return 50; + if (portion >= 150 && portion < 300) return 56; + if (portion > 300 && portion < 600) return 70; + if (portion > 600 && portion < 900) return 85; + else return 90 +} + +export function ProgressBar({ company }: ProgressBarProps) { + const [progress, setProgress] = useState(0) + const [secondProgress, setSecondProgress] = useState(0) + const revenue = new BigNumber(company.yahooStockV2Summary.financialData.totalRevenue.raw) + const formatRevenue = revenue.toFormat(0, { groupSeparator: ',', groupSize: 3 }) + + const total = new BigNumber(1000000000) + const portion = revenue.dividedBy(total).toNumber() + + const marketCap = new BigNumber(company.yahooStockV2Summary.price.marketCap.raw) + const formatRevenue2 = marketCap.toFormat(0, { groupSeparator: ',', groupSize: 3 }) + const totalMarket = new BigNumber(100000000) + const portion2 = revenue.dividedBy(totalMarket).toNumber() + + useEffect(() => { + setProgress(getPortion(portion)); + setSecondProgress(getPortion2(portion2)) + }, [company]) + + return ( +
+
+

Market Capitalization

+

${formatRevenue2}

+
+ + +
+

2023 Revenue

+

${formatRevenue}

+
+ +
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/column-header.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/column-header.tsx new file mode 100644 index 0000000..fd56a18 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/column-header.tsx @@ -0,0 +1,71 @@ +import { + ArrowDownIcon, + ArrowUpIcon, + CaretSortIcon, + EyeNoneIcon, +} from "@radix-ui/react-icons" +import { Column } from "@tanstack/react-table" + +import { cn } from "@/lib/utils" +import { Button } from "@/components/shadcn-ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/shadcn-ui/dropdown-menu" + +interface DataTableColumnHeaderProps + extends React.HTMLAttributes { + column: Column + title: string +} + +export function DataTableColumnHeader({ + column, + title, + className, +}: DataTableColumnHeaderProps) { + if (!column.getCanSort()) { + return
{title}
+ } + + return ( +
+ + + + + + column.toggleSorting(false)}> + + Asc + + column.toggleSorting(true)}> + + Desc + + + column.toggleVisibility(false)}> + + Hide + + + +
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/columns.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/columns.tsx new file mode 100644 index 0000000..73fec23 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/columns.tsx @@ -0,0 +1,124 @@ +"use client" + +import Image from "next/image" +import { ColumnDef } from "@tanstack/react-table" + +import { useState } from "react" +import { cn } from "@/lib/utils" +import { DataTableColumnHeader } from "./column-header" +import { MoveRight } from "lucide-react" +import { useTicker } from "@/hooks/use-ticker" +import { useAnimation } from "@/hooks/use-animation" +import { Button } from "@/components/aceternity-ui/moving-border" + +// This type is used to define the shape of our data. +export type CompanyDef = { + symbol: string, + sector: string, + trend: string, + price: number, + percentChg: number +} + +export const columns: ColumnDef[] = [ + { + accessorKey: "symbol", + header: ({ column }) => ( +
+ ), + cell: ({ row }) => { + + const data: string = row.getValue("symbol") + const imgPath = "/logos/" + data.toLowerCase() + ".svg"; + return ( +
+ Company logo + {data} +
+ ) + } + }, + { + accessorKey: "sector", + header: ({ column }) => ( +
+ ), + cell: ({ row }) => { + return
{row.getValue("sector")}
+ }, + }, + { + accessorKey: "trend", + header: ({ column }) => ( +
+ ), + cell: ({ row }) => { + const data = row.getValue("trend") + return ( +
+ {row.getValue("trend")} +
+ ) + }, + }, + { + accessorKey: "price", + header: ({ column }) => ( +
+ ), + cell: ({ row }) => { + const amount = parseFloat(row.getValue("price")) + const formatted = new Intl.NumberFormat("en-US", { + style: "currency", + currency: "USD", + }).format(amount) + + return
{formatted}
+ }, + + }, + { + accessorKey: "percentChg", + header: ({ column }) => ( +
+ ), + cell: ({ row }) => { + const data = parseFloat(row.getValue("percentChg")) + const { animatedId, setAnimatedId } = useAnimation(); + + const { ticker } = useTicker() + const style = data < 0 ? "text-red-400" : "text-green-400"; + return ( +
+

+ {data < 0 ? data.toFixed(2) + "%" : '+' + data.toFixed(2) + '%'} +

+ {ticker === row.getValue("symbol") && } +
+ + ) + + + } + } +] diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/data-table.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/data-table.tsx new file mode 100644 index 0000000..3b11628 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/data-table.tsx @@ -0,0 +1,203 @@ +"use client" +import * as React from "react" +import { + ColumnDef, + ColumnFiltersState, + VisibilityState, + SortingState, + + flexRender, + getCoreRowModel, + getFilteredRowModel, + getSortedRowModel, + getPaginationRowModel, + + useReactTable, +} from "@tanstack/react-table" +import { useToast } from "@/components/shadcn-ui/use-toast" +import { Row } from "@tanstack/react-table" +import { + DropdownMenu, + DropdownMenuCheckboxItem, + DropdownMenuContent, + DropdownMenuTrigger, +} from "@/components/shadcn-ui/dropdown-menu" +import { DataTablePagination } from "./table-pagination" +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/shadcn-ui/table" + +interface DataTableProps { + columns: ColumnDef[] + data: TData[] +} + +interface TickerProps { + ticker: string, +} + +import { Input } from "@/components/shadcn-ui/input" +import { Button } from "@/components/shadcn-ui/button" +import { useTicker } from "@/hooks/use-ticker" +import { SlidersHorizontal } from "lucide-react" + +export function DataTable({ + columns, + data +}: DataTableProps) { + const { toast } = useToast() + const [sorting, setSorting] = React.useState([]) + const [columnFilters, setColumnFilters] = React.useState([]) + const [columnVisibility, setColumnVisibility] = React.useState({}) + + const table = useReactTable({ + data, + columns, + getCoreRowModel: getCoreRowModel(), + onColumnFiltersChange: setColumnFilters, + getFilteredRowModel: getFilteredRowModel(), + onColumnVisibilityChange: setColumnVisibility, + onSortingChange: setSorting, + getSortedRowModel: getSortedRowModel(), + getPaginationRowModel: getPaginationRowModel(), + initialState: { + pagination: { + pageSize: 50, + } + }, + autoResetAll: false, + autoResetPageIndex: false, + + state: { + columnFilters, + columnVisibility, + sorting + }, + }) + + const { ticker, setTicker, isLiking } = useTicker() + const handleClick = (symbol: string) => { + if (isLiking) return ( + toast({ + variant: "destructive", + title: "Please wait.", + }) + ); + setTicker(symbol) + } + + const getSymbol = (row: Row) => { + const data = row.original as { symbol: string } + return data.symbol + } + return ( + <> +
+ + table.getColumn("symbol")?.setFilterValue(event.target.value) + } + className="max-w-[200px]" + /> + + + + + + {table + .getAllColumns() + .filter( + (column) => column.getCanHide() + ) + .map((column) => { + return ( + + column.toggleVisibility(!!value) + } + > + {column.id} + + ) + })} + + +
+
+ + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => { + return ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext() + )} + + ) + })} + + ))} + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + handleClick(getSymbol(row))} + className={` + border-t + border-muted-foreground/20 + text-xs + font-medium + hover:bg-indigo-600/10 + hover:cursor-pointer + transition duration-75 + ${getSymbol(row) === ticker && 'ring-2 ring-cyan-600 ring-inset'} + `} + > + {row.getVisibleCells().map((cell) => ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + )) + ) : ( + + + No results. + + + )} + +
+
+ + + ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/search-input.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/search-input.tsx new file mode 100644 index 0000000..a367866 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/search-input.tsx @@ -0,0 +1,45 @@ +import { Input } from "@/components/shadcn-ui/input" +import { Search } from "lucide-react" +import { useEffect, useState, ChangeEventHandler } from "react"; +import { useRouter, useSearchParams } from "next/navigation"; +import qs from "query-string"; + + + +export const SearchInput = () => { + + const router = useRouter(); + const searchParams = useSearchParams(); + + const symbol = searchParams.get("symbol"); + + const [value, setValue] = useState(symbol || ""); + + const onChange: ChangeEventHandler = (e) => { + setValue(e.target.value); + }; + + useEffect(() => { + const query = { + symbol: value + }; + + const url = qs.stringifyUrl({ + url: window.location.href, + query + }, { skipNull: true, skipEmptyString: true }); + + router.push(url); + }, [value, router]) + + return ( +
+ +
+ ) +} \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/stock-list.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/stock-list.tsx new file mode 100644 index 0000000..05d4de9 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/stock-list.tsx @@ -0,0 +1,198 @@ +"use client" + +import Image from "next/image"; + +import { SearchInput } from "./search-input"; +import { MdCorporateFare } from "react-icons/md"; +import { LiaIndustrySolid } from "react-icons/lia"; +import { MdCurrencyExchange } from "react-icons/md"; +import { BsPeople } from "react-icons/bs"; +import { MoveRight } from "lucide-react"; + +import { + Table, + TableBody, + TableCaption, + TableCell, + TableHead, + TableHeader, + TableRow +} from "@/components/shadcn-ui/table"; +import { Company } from "@prisma/client"; + +interface StockListProps { + ticker: string + setTicker: (ticker: string) => void + companies: Company[] + animatedClick: () => void; + searchSymbol: string +} + + +const getSentimentalColor = (input: string) => { + switch (input) { + case "strong buy": + return "text-green-500"; + case "buy": + return "text-green-600"; + case "sell": + return "text-red-600"; + case "strong sell": + return "text-red-500"; + default: + return "text-blue-500"; + } +}; + +const getChangeFormat = (percent: number) => { + if (percent < 0) return percent.toFixed(2) + '%' + return '+' + percent.toFixed(2) + '%' +} + +export const StockList = ({ ticker, setTicker, companies, animatedClick }: StockListProps) => { + + + const handleClick = (company: Company) => { + // Call back func + setTicker(company.symbol); + } + + return ( +
+
+ +
+

Results

+
+
+ + +

A list of top ranked companies.

+
+ + + {/* +
+ +

Rank

+
+
*/} + +
+ +

Company

+
+
+ +
+ +

Sector

+
+
+ +
+ +

Trend

+
+
+ +
+ +

Price

+
+
+ +
+ +

Chg

+
+
+ +
+
+ + {companies + .map((item: Company, index: number) => ( + handleClick(item)} + className={` + border-t-2 + border-t-muted-foreground/20 + text-xs + font-medium + hover:bg-indigo-600/10 + hover:cursor-pointer + ${item.symbol === ticker && 'border-2 border-indigo-600'} + `} + > + {/* +
+ {item.rank} +
+
*/} + +
+
+ +
+ {item.symbol} +
+
+ +

{item.yahooStockV2Summary.summaryProfile.sector}

+
+ + {item.yahooStockV2Summary.financialData.recommendationKey} + + + ${item.price} + + 0 ? "text-green-500" : "text-red-500"}`}> + {getChangeFormat(item.yahooMarketV2Data.regularMarketChangePercent)} + + + {item.symbol === ticker && + + } + +
+ ))} +
+
+
+ + ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/table-container.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/table-container.tsx new file mode 100644 index 0000000..e552312 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/table-container.tsx @@ -0,0 +1,26 @@ + +import { DataTable } from "./data-table" +import { columns, CompanyDef } from "./columns" +import { Company } from "@prisma/client" + +interface TableContainerProps { + companies: Company[], +} + +const TableContainer = ({ companies }: TableContainerProps) => { + const data = companies.map((item) => ( + { + symbol: item.symbol, + sector: item.yahooStockV2Summary.summaryProfile.sector, + trend: item.yahooStockV2Summary.financialData.recommendationKey, + price: item.price, + percentChg: item.yahooStockV2Summary.price.regularMarketChange.fmt + } + )) + + return ( + + ) +} + +export default TableContainer \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/components/table/table-pagination.tsx b/infrastructure/application/app/(root)/(routes)/market/components/table/table-pagination.tsx new file mode 100644 index 0000000..6805657 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/table/table-pagination.tsx @@ -0,0 +1,96 @@ +import { + ChevronLeftIcon, + ChevronRightIcon, + DoubleArrowLeftIcon, + DoubleArrowRightIcon, +} from "@radix-ui/react-icons" +import { Table } from "@tanstack/react-table" + +import { Button } from "@/components/shadcn-ui/button" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/shadcn-ui/select" + +interface DataTablePaginationProps { + table: Table +} + +export function DataTablePagination({ + table, +}: DataTablePaginationProps) { + return ( +
+
+ {table.getFilteredRowModel().rows.length} row(s). +
+
+
+

Rows per page

+ +
+
+ Page {table.getState().pagination.pageIndex + 1} of{" "} + {table.getPageCount()} +
+
+ + + + +
+
+
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/transaction/buy-transaction.tsx b/infrastructure/application/app/(root)/(routes)/market/components/transaction/buy-transaction.tsx new file mode 100644 index 0000000..fbcb268 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/transaction/buy-transaction.tsx @@ -0,0 +1,224 @@ +"use client" + +import * as z from "zod"; +import { Label } from "@/components/shadcn-ui/label" +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetOverlay, + SheetTitle, + SheetTrigger, +} from "@/components/shadcn-ui/sheet" +import { Account, Company } from "@prisma/client" +import axios from "axios"; +//import { safeParse } from 'zod' + +import * as React from "react" + +import { Form, FormControl, FormField, FormItem, FormMessage } from "@/components/shadcn-ui/form"; +import { useForm } from "react-hook-form"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { Input } from "@/components/shadcn-ui/input"; +import { useState, useEffect } from "react"; +import toast from "react-hot-toast"; +import { useMutation, QueryClient, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Separator } from "@/components/shadcn-ui/separator"; +interface TransactionProps { + company: Company, +} + +type Transaction = { + value: number, + symbol: string +} + +export function BuyTransaction({ company }: TransactionProps) { + + const client = useQueryClient(); + const data: Account = client.getQueryData(['getAccount']) + const symbol = company.symbol; + + const formSchema = z.object({ + prompt: z.coerce.number().int() + .positive() + .gte(50, { + message: "Minimum $50 is required." + }) + .lte(data.accountBalance, { + message: `Must be less than or equal to your account balance. Your current account balance: ${data.accountBalance}`, + }) + }); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + prompt: 50 + } + }); + + const onSubmit = async (values: z.infer) => { + try { + const transaction = { + value: values.prompt, + symbol + } + updatePortfolio(transaction) + // clear user input + } catch (error) { + console.log("Error submitting the form") + } finally { + //router.refresh() + } + } + + const queryClient = useQueryClient(); + const { mutate: updatePortfolio, isPending } = useMutation({ + mutationFn: (transaction: Transaction) => { + return axios.patch("/api/transaction/buy", { transaction }) + }, + onError: (error) => { + console.log(error) + }, + onSuccess: () => { + toast.success("BUY Successful") + setIsOpen(false) + form.reset() + setShareLabel('0') + queryClient.invalidateQueries({ + queryKey: ['getAccount'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getPortfolio'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getAccount2'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getTransaction'], + exact: true, + refetchType: 'all' + }) + } + }) + + const [isOpen, setIsOpen] = React.useState(false) + const isLoading = form.formState.isSubmitting; + const [shareLabel, setShareLabel] = useState("") + return ( + + + + + setIsOpen(false)} + /> + + + Buy Stock + + Make stock transactions here + + +
+
+ + +
+
+ + +
+ +
+ + +
+ +
+ +
+

Amount

+ ( + + + { + field.onChange(e) + setShareLabel(`${parseFloat(e.target.value) ? parseFloat(e.target.value) / company.price : 0}`); + }} + /> + + + + )} + /> +

$

+
+ {shareLabel !== '0' && +
+ + +

Shares

+ +
+ } +
+ + + +
+
+ +
+
+
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/transaction/sell-transaction.tsx b/infrastructure/application/app/(root)/(routes)/market/components/transaction/sell-transaction.tsx new file mode 100644 index 0000000..cc81e91 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/transaction/sell-transaction.tsx @@ -0,0 +1,229 @@ +"use client" + +import { Label } from "@/components/shadcn-ui/label"; +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetOverlay, + SheetTitle, + SheetTrigger, +} from "@/components/shadcn-ui/sheet"; +import { Account, Company, Portfolio, Portfolio_Company } from "@prisma/client"; +import axios from "axios"; +import { useRouter } from "next/navigation"; +import * as z from "zod"; +//import { safeParse } from 'zod' + +import * as React from "react"; + +import { Form, FormControl, FormField, FormItem, FormMessage } from "@/components/shadcn-ui/form"; +import { Input } from "@/components/shadcn-ui/input"; +import { zodResolver } from "@hookform/resolvers/zod"; +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { useState } from "react"; +import { useForm } from "react-hook-form"; +import toast from "react-hot-toast"; +interface TransactionProps { + company: Company, +} + +type Transaction = { + value: number, + symbol: string +} + +export function SellTransaction({ company }: TransactionProps) { + + const client = useQueryClient(); + const data: Account & Portfolio & Portfolio_Company[] = client.getQueryData(['getAccount']) + const symbol = company.symbol; + const portfolioStocks: Portfolio_Company[] = data.portfolio.companies + + const foundStock = portfolioStocks.filter((item) => item.symbol.includes(company.symbol)) + const foundStockPrice = foundStock.length === 0 ? 0 : foundStock[0].shares * company.price + + const shareAmount = foundStock.length === 0 ? 0 : foundStock[0].shares + const price = portfolioStocks.length === 0 ? 0 : foundStockPrice + + const formSchema = z.object({ + prompt: z.coerce.number().int() + .positive() + .gt(0, { + message: "Must be greater than $0.00" + }) + .lte(price, { + message: `You currently have ${shareAmount} shares. Not enough ${company.symbol} stocks` + }) + }); + + const form = useForm>({ + resolver: zodResolver(formSchema), + defaultValues: { + prompt: 0 + } + }); + + const onSubmit = async (values: z.infer) => { + try { + const transaction = { + value: values.prompt, + symbol + } + updatePortfolio(transaction) + } catch (error) { + console.log("Error submitting the form") + } + } + + const queryClient = useQueryClient(); + const { mutate: updatePortfolio, isPending } = useMutation({ + mutationFn: (transaction: Transaction) => { + return axios.patch("/api/transaction/sell", { transaction }) + }, + onError: (error) => { + console.log(error) + }, + onSuccess: () => { + toast.success("SELL Successful") + setIsOpen(false) + form.reset() + setShareLabel('0') + queryClient.invalidateQueries({ + queryKey: ['getAccount'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getPortfolio'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getAccount2'], + exact: true, + refetchType: 'all' + }) + queryClient.invalidateQueries({ + queryKey: ['getTransaction'], + exact: true, + refetchType: 'all' + }) + } + }) + + const [isOpen, setIsOpen] = React.useState(false) + const isLoading = form.formState.isSubmitting; + const router = useRouter() + + const [shareLabel, setShareLabel] = useState("") + return ( + + + + + setIsOpen(false)} + /> + + + Sell Stock + + Make stock transactions here + + +
+
+ + +
+
+ + +
+ +
+ + +
+ +
+ +
+

Amount

+ ( + + + { + field.onChange(e) + setShareLabel(`${parseFloat(e.target.value) ? parseFloat(e.target.value) / company.price : 0}`); + }} + /> + + + + )} + /> +

$

+
+ {shareLabel !== '0' && +
+ + +

Shares

+ +
+ } +
+ + + +
+
+ +
+
+
+ ) +} diff --git a/infrastructure/application/app/(root)/(routes)/market/components/transaction/wrapper.tsx b/infrastructure/application/app/(root)/(routes)/market/components/transaction/wrapper.tsx new file mode 100644 index 0000000..00b3148 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/components/transaction/wrapper.tsx @@ -0,0 +1,39 @@ +import { Company } from "@prisma/client" +import { BuyTransaction } from "./buy-transaction" +import { SellTransaction } from "./sell-transaction" +import { useQuery } from "@tanstack/react-query" +import axios from "axios" +import { useTicker } from "@/hooks/use-ticker" +import { Skeleton } from "@/components/shadcn-ui/skeleton" +interface WrapperInterface { + company: Company +} + +const Wrapper = ({ + company +}: WrapperInterface) => { + const { ticker } = useTicker() + const { isLoading } = useQuery({ + queryKey: ['getAccount'], + queryFn: async () => { + const response = await axios.get('/api/account') + return response.data; + }, + }) + if (isLoading) return ( +
+ + +
+ ) + return ( + <> +
+ + +
+ + ) +} + +export default Wrapper \ No newline at end of file diff --git a/infrastructure/application/app/(root)/(routes)/market/page.tsx b/infrastructure/application/app/(root)/(routes)/market/page.tsx new file mode 100644 index 0000000..0ae55e4 --- /dev/null +++ b/infrastructure/application/app/(root)/(routes)/market/page.tsx @@ -0,0 +1,27 @@ +import prismadb from "@/lib/prismadb" +import MarketContainer from "./components/market-container" +import { auth } from "@clerk/nextjs" +import { redirect } from "next/navigation" + +interface MarketProps { + searchParams: { + symbol: string + } +} + +const Market = async () => { + const { userId } = auth() + if (!userId) { + redirect("/") + } + // fetching + const companies = await prismadb.company.findMany() + + return ( + + ) +} + +export default Market \ No newline at end of file diff --git a/infrastructure/application/app/(root)/layout.tsx b/infrastructure/application/app/(root)/layout.tsx new file mode 100644 index 0000000..e1b66fa --- /dev/null +++ b/infrastructure/application/app/(root)/layout.tsx @@ -0,0 +1,21 @@ +import { Navbar } from "@/components/app/navbar" +import Sidebar from "@/components/app/sidebar" + +const RootLayout = ({ children }: { children: React.ReactNode }) => { + return ( +
+ +
+ +
+ +
+ {children} +
+
+ ) +} + +export default RootLayout \ No newline at end of file diff --git a/infrastructure/application/app/api/account/route.ts b/infrastructure/application/app/api/account/route.ts new file mode 100644 index 0000000..89e4893 --- /dev/null +++ b/infrastructure/application/app/api/account/route.ts @@ -0,0 +1,33 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; + +export async function GET() { + try { + const { userId } = auth() + if (!userId) { + return new NextResponse("Unauthorized", { status: 401 }) + } + const account = await prismadb.account.findFirst({ + where: { + id: userId + }, + include: { + portfolio: { + include: { + companies: true + + } + } + } + } + ) + //console.log(account); + console.log("account GET route runs") + return NextResponse.json(account, { status: 200 }) + } catch (error) { + console.log(error) + return new NextResponse("Internal Error", { status: 500 }) + } +} + diff --git a/infrastructure/application/app/api/admin/authenticate/route.ts b/infrastructure/application/app/api/admin/authenticate/route.ts new file mode 100644 index 0000000..c16c72e --- /dev/null +++ b/infrastructure/application/app/api/admin/authenticate/route.ts @@ -0,0 +1,13 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; + +export async function POST(request: Request) { + const body = await request.json(); + const { values } = body; + console.log(values) + if (values.username !== '12345') + return new Response("InCorrect Password!", { status: 400 }); + + return new NextResponse("Correct key", { status: 200 }); +} \ No newline at end of file diff --git a/infrastructure/application/app/api/admin/route.ts b/infrastructure/application/app/api/admin/route.ts new file mode 100644 index 0000000..ae65109 --- /dev/null +++ b/infrastructure/application/app/api/admin/route.ts @@ -0,0 +1,31 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; + +export async function GET() { + try { + const result = [] + const accounts = await prismadb.account.findMany({ + include: { + portfolio: { + include: { + companies: true + } + }, + watchlist: { + include: { + companies: true + } + }, + cards: true, + transactions: true + } + }); + + console.log(accounts) + return NextResponse.json(accounts, { status: 200 }); + } catch (error) { + console.log(error) + return new NextResponse("Internal error", { status: 500 }); + } +} \ No newline at end of file diff --git a/infrastructure/application/app/api/card/add/route.ts b/infrastructure/application/app/api/card/add/route.ts new file mode 100644 index 0000000..564918b --- /dev/null +++ b/infrastructure/application/app/api/card/add/route.ts @@ -0,0 +1,27 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; +import { Portfolio_Company, Watchlist_Company, Company } from "@prisma/client"; + +export async function POST(request: Request) { + const { userId } = auth() + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + const { data, color } = await request.json() + const { name, cardNumber, value, expiryDate } = data + await prismadb.card.create({ + data: { + holderName: name, + value: value, + expiration: expiryDate, + cardDigits: cardNumber, + type: "VISA", + accountId: userId, + color + }, + }) + + + console.log("card/add POST runs") + return NextResponse.json("", { status: 200 }); +} \ No newline at end of file diff --git a/infrastructure/application/app/api/card/deposit/route.ts b/infrastructure/application/app/api/card/deposit/route.ts new file mode 100644 index 0000000..bb29e04 --- /dev/null +++ b/infrastructure/application/app/api/card/deposit/route.ts @@ -0,0 +1,68 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; + + +export async function PATCH(request: Request) { + try { + const { userId } = auth() + console.log(userId) + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + + const { cardNum, value } = await request.json() + + const foundCard = await prismadb.card.findFirst( + { where: { cardDigits: cardNum } } + ) + if (value > foundCard.value) { + await prismadb.account.update({ + where: { + id: userId + }, + data: { + transactions: { + create: { + type: "deposit", + amount: value, + status: 'failed' + } + } + } + }) + return new NextResponse("Insufficient balance", { status: 400 }) + } + console.log(foundCard) + await prismadb.card.update({ + where: { + cardDigits: cardNum + }, + data: { + value: foundCard.value - value + } + }) + + const foundAccount = await prismadb.account.findFirst({ where: { id: userId } }) + await prismadb.account.update({ + where: { + id: userId + }, + data: { + accountBalance: foundAccount.accountBalance + value, + accountValue: foundAccount.accountValue + value, + transactions: { + create: { + type: `deposit`, + status: 'success', + amount: value, + } + } + } + }) + console.log("card/deposit PATCH runs") + return new NextResponse("Success", { status: 200 }) + } catch (error) { + console.log(error); + return new NextResponse("Internal Error", { status: 500 }) + } +} \ No newline at end of file diff --git a/infrastructure/application/app/api/card/remove/route.ts b/infrastructure/application/app/api/card/remove/route.ts new file mode 100644 index 0000000..c1442f5 --- /dev/null +++ b/infrastructure/application/app/api/card/remove/route.ts @@ -0,0 +1,20 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; + +export async function POST(request: Request) { + const { userId } = auth() + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + const body = await request.json() + console.log(body) + await prismadb.card.delete({ + where: { + cardDigits: body.cardNum + } + }) + + + console.log("card/remove POST route runs") + return new NextResponse("Delete card successfully", { status: 200 }); +} \ No newline at end of file diff --git a/infrastructure/application/app/api/card/route.ts b/infrastructure/application/app/api/card/route.ts new file mode 100644 index 0000000..f2f29b8 --- /dev/null +++ b/infrastructure/application/app/api/card/route.ts @@ -0,0 +1,21 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; + +export async function GET() { + const { userId } = auth() + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + + const result = await prismadb.account.findFirst({ + where: { + id: userId + }, + include: { + cards: true + } + }) + + console.log("card GET route runs") + return NextResponse.json(result.cards, { status: 200 }); +} \ No newline at end of file diff --git a/infrastructure/application/app/api/card/withdraw/route.ts b/infrastructure/application/app/api/card/withdraw/route.ts new file mode 100644 index 0000000..237f079 --- /dev/null +++ b/infrastructure/application/app/api/card/withdraw/route.ts @@ -0,0 +1,67 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; + + +export async function PATCH(request: Request) { + try { + const { userId } = auth() + console.log(userId) + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + + const { cardNum, value } = await request.json() + + const foundCard = await prismadb.card.findFirst({ where: { cardDigits: cardNum } }) + const foundAccount = await prismadb.account.findFirst({ where: { id: userId } }) + + if (value > foundAccount.accountBalance) { + await prismadb.account.update({ + where: { + id: userId + }, + data: { + transactions: { + create: { + type: "withdraw", + amount: value, + status: 'failed' + } + } + } + }) + return new NextResponse("Insufficient account balance", { status: 400 }) + } + + await prismadb.card.update({ + where: { + cardDigits: cardNum + }, + data: { + value: foundCard.value + value + } + }) + + await prismadb.account.update({ + where: { + id: userId + }, + data: { + accountBalance: foundAccount.accountBalance - value, + accountValue: foundAccount.accountValue - value, + transactions: { + create: { + type: `withdraw`, + status: 'success', + amount: value, + } + } + } + }) + console.log("card/deposit PATCH runs") + return new NextResponse("Success", { status: 200 }) + } catch (error) { + console.log(error); + return new NextResponse("Internal Error", { status: 500 }) + } +} \ No newline at end of file diff --git a/infrastructure/application/app/api/chat/route.ts b/infrastructure/application/app/api/chat/route.ts new file mode 100644 index 0000000..bb993c9 --- /dev/null +++ b/infrastructure/application/app/api/chat/route.ts @@ -0,0 +1,52 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { Configuration, OpenAIApi } from "openai"; + +const configuration = new Configuration({ + apiKey: process.env.OPENAI_API_KEY, +}); + +const openai = new OpenAIApi(configuration); + +export async function POST( + request: Request +) { + try { + + const { userId } = auth(); + const { messages } = await request.json(); + + + if (!userId) { + return new NextResponse("Unauthorized", { status: 401 }); + } + + if (!configuration.apiKey) { + return new NextResponse("OpenAI API Key not configured.", { status: 500 }); + } + + if (!messages) { + return new NextResponse("Messages are required", { status: 400 }); + } + const systemMessage = { + role: "system", + content: "you are a financial specialist who works at Dynamite Trade that helps clients with their investment decisions. YOU ONLY NEED TO ANSWER QUESTIONS THAT ARE related to financial aspect. If the user asks something else, please just reply you're only capable of answering financial related question. Additional infor regarding your working company, Dynamite Trade is a stock trading platform that allows customers to buy and sell stocks, view in-depth stock data, charts and AI integrated. You have in-depth knowledge regarding company overview in nasdaq or in the US market. ", + } + + const messageArray = [] + messageArray.push(systemMessage, messages) + console.log(messageArray) + const response = await openai.createChatCompletion({ + model: "gpt-3.5-turbo", + messages: messageArray, + max_tokens: 80 + }); + messageArray.length = 0; + const formatMessage = response.data.choices[0].message + return NextResponse.json(formatMessage); + + } catch (error) { + return new NextResponse("Internal Error", { status: 500 }); + } +} diff --git a/infrastructure/application/app/api/portfolio/route.ts b/infrastructure/application/app/api/portfolio/route.ts new file mode 100644 index 0000000..5154991 --- /dev/null +++ b/infrastructure/application/app/api/portfolio/route.ts @@ -0,0 +1,30 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; +import { Portfolio_Company, Watchlist_Company, Company } from "@prisma/client"; + + + +export async function GET() { + const { userId } = auth() + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + + const portfolio = await prismadb.portfolio.findFirst({ + where: { + accountId: userId + }, + include: { + companies: { + include: { + company: true + } + } + } + }) + console.log(portfolio); + const portfolioStocks = portfolio.companies.map((item: Portfolio_Company & Company) => item.company); + console.log("portfolio GET route runs") + return NextResponse.json(portfolio, { status: 200 }); + +} \ No newline at end of file diff --git a/infrastructure/application/app/api/transaction/buy/route.ts b/infrastructure/application/app/api/transaction/buy/route.ts new file mode 100644 index 0000000..054d278 --- /dev/null +++ b/infrastructure/application/app/api/transaction/buy/route.ts @@ -0,0 +1,99 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; + +export async function PATCH(request: Request) { + try { + const { transaction } = await request.json() + const { value, symbol } = transaction + + const { userId } = auth() + if (!userId) { + return new NextResponse("Unauthorized", { status: 401 }); + } + + const promises = [ + prismadb.portfolio.findFirst({ + where: { + accountId: userId + } + }), + prismadb.company.findFirst({ + where: { + symbol + } + }), + prismadb.account.findFirst({ + where: { + id: userId + } + }) + ] + const [portfolio, company, account] = await Promise.all(promises) + + const portfolioShares = await prismadb.portfolio_Company.findFirst({ + where: { + portfolioId: portfolio.id, + companyId: company.id, + }, + select: { + shares: true + } + }) + + const temp = portfolioShares ? portfolioShares.shares : 0 + + const allPromies = [ + prismadb.portfolio_Company.upsert({ + where: { + portfolioId_companyId: { + portfolioId: portfolio.id, + companyId: company.id, + }, + }, + update: { + shares: temp + value / company.price + }, + create: { + portfolioId: portfolio.id, + companyId: company.id, + symbol, + shares: value / company.price + }, + }), + prismadb.portfolio.update({ + where: { + accountId: userId + }, + data: { + portfolioVal: portfolio.portfolioVal + value + } + }), + + prismadb.account.update({ + where: { + id: userId + }, + data: { + accountBalance: account.accountBalance - value, + transactions: { + create: { + type: `buy ${symbol}`, + status: 'success', + amount: value, + } + } + } + }), + + ] + await Promise.all(allPromies); + + + console.log("api/transaction/buy PATCH runs") + return new NextResponse("Success", { status: 200 }) + } catch (error) { + console.log(error); + return new NextResponse("Internal Error", { status: 500 }) + } +} \ No newline at end of file diff --git a/infrastructure/application/app/api/transaction/route.ts b/infrastructure/application/app/api/transaction/route.ts new file mode 100644 index 0000000..10fe789 --- /dev/null +++ b/infrastructure/application/app/api/transaction/route.ts @@ -0,0 +1,33 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; + +export async function GET() { + try { + const { userId } = auth() + if (!userId) { + return new NextResponse("Unauthorized", { status: 401 }) + } + + const account = await prismadb.account.findFirst({ + where: { + id: userId + }, + include: { + transactions: { + orderBy: { + createdAt: 'desc' + } + }, + + } + } + ) + console.log("api/transaction GET route runs") + // return Transaction[] + return NextResponse.json(account.transactions, { status: 200 }) + } catch (error) { + console.log(error) + return new NextResponse("Internal Error", { status: 500 }) + } +} \ No newline at end of file diff --git a/infrastructure/application/app/api/transaction/sell/route.ts b/infrastructure/application/app/api/transaction/sell/route.ts new file mode 100644 index 0000000..dd76327 --- /dev/null +++ b/infrastructure/application/app/api/transaction/sell/route.ts @@ -0,0 +1,109 @@ +import { auth } from "@clerk/nextjs"; +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; + +export async function PATCH(request: Request) { + try { + const { transaction } = await request.json() + const { value, symbol } = transaction + + const { userId } = auth() + if (!userId) { + return new NextResponse("Unauthorized", { status: 401 }); + } + + const promises = [ + prismadb.portfolio.findFirst({ + where: { + accountId: userId + } + }), + prismadb.company.findFirst({ + where: { + symbol + } + }), + prismadb.account.findFirst({ + where: { + id: userId + } + }) + ] + const [portfolio, company, account] = await Promise.all(promises) + + await prismadb.account.update({ + where: { + id: userId + }, + data: { + accountBalance: account.accountBalance + value, + accountValue: account.accountValue + value + } + }) + await prismadb.portfolio.update({ + where: { + accountId: userId + }, + data: { + portfolioVal: portfolio.portfolioVal - value + } + }) + + const portfolioShares = await prismadb.portfolio_Company.findFirst({ + where: { + portfolioId: portfolio.id, + companyId: company.id, + }, + select: { + shares: true + } + }) + + + await prismadb.portfolio_Company.update({ + where: { + portfolioId_companyId: { + portfolioId: portfolio.id, + companyId: company.id + } + }, + data: { + shares: portfolioShares.shares - value / company.price + } + }) + + if (portfolioShares.shares - value / company.price <= 0.1) { + await prismadb.portfolio_Company.delete({ + where: { + portfolioId_companyId: { + portfolioId: portfolio.id, + companyId: company.id + } + } + }) + } + + + + + await prismadb.account.update({ + where: { + id: userId + }, + data: { + transactions: { + create: { + type: `sell ${symbol}`, + status: 'success', + amount: value, + } + } + } + }), + console.log("api/transaction/sell PATCH runs") + return new NextResponse("Success", { status: 200 }) + } catch (error) { + console.log(error); + return new NextResponse("Internal Error", { status: 500 }) + } +} \ No newline at end of file diff --git a/infrastructure/application/app/api/watchlist/route.ts b/infrastructure/application/app/api/watchlist/route.ts new file mode 100644 index 0000000..1d4401d --- /dev/null +++ b/infrastructure/application/app/api/watchlist/route.ts @@ -0,0 +1,67 @@ +import { NextResponse } from "next/server"; +import prismadb from "@/lib/prismadb"; +import { auth } from "@clerk/nextjs"; +import { Watchlist_Company, Company } from "@prisma/client"; + +export async function GET() { + const { userId } = auth() + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + + const watchlist = await prismadb.watchlist.findFirst({ + where: { + accountId: userId + }, + include: { + companies: { + include: { + company: true + } + } + } + }) + + const watchlistStocks = watchlist.companies.map((item: Watchlist_Company & Company) => item.company); + console.log("/api/watchlist GET runs") + return NextResponse.json(watchlistStocks, { status: 200 }); +} + +export async function PATCH(request: Request) { + const { isLiked, ticker } = await request.json() + const { userId } = auth() + if (!userId) + return new NextResponse("Unauthorized", { status: 401 }); + + const watchlist = await prismadb.watchlist.findFirst({ + where: { + accountId: userId, + } + }) + const company = await prismadb.company.findFirst({ + where: { + symbol: ticker + } + }) + + if (isLiked) { + await prismadb.watchlist_Company.create({ + data: { + watchlistId: watchlist.id, + companyId: company.id, + symbol: ticker + } + }) + } else { + await prismadb.watchlist_Company.delete({ + where: { + watchlistId_companyId: { + watchlistId: watchlist.id, + companyId: company.id + } + } + }) + } + + console.log("watchlist PATCH runs") + return new NextResponse("PATCH watchlist successful", { status: 200 }); +} \ No newline at end of file diff --git a/infrastructure/application/app/favicon.ico b/infrastructure/application/app/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..299db06d1c49339e8b22d094f01916b0bae9a2fc GIT binary patch literal 3230 zcmbuBdrXye7{{N3I8RPIW15<*F#MxNH&?(ax5%{`YyFc>S9Y`JW<_OLj#>>^HrA>& z63Zr25auYT)e#t?5d;T|bm~CRr>k)$?)*7)YsQr@%sWkbHK`g!&^f`Lr?Vg_dn3r*B8{= z+xz6$*w`s-euP6FHmT%gyW)^1Z{e=7o0^&u{IAF7F^zR{a`G9*_b~ntEe7L+I}oK6N%P zIMiboL_YQ`EG#(PdN`bc)1%yfgUbhXb#=b*^aOcjJ~A>=8X6i_vi?itG&!^l4h{;w zc7o$_AH; z2=)i@brkGi#lVM)Jp4irnZ3-sP?NRjautKPwGYza8AbjpRJZuTR|fgIz@7~Cgm1A} z^b9+Ozt_3H;`RfFVf-9oF9ndF5b}Bko!RqnoCLRBdAjZ4KD-tdHPv)L@zo+ei> zaz6;x*i!#D7-Kr$0p}B}eJD6M*aDV?nprtJLe(cR9q@#&^W^z!`t=a|#g|;{rp^wH z{RX(Y=cdT*dHf#+`}wA(rfu}@W^{YulSLi#@t4E+S#R*c+|Qb%f`S4c%^m%|9zSOl zpBP7&`yD>!{#PB(8vD5Fuxp3^uhgK1@tn@<{?j$&|6XtKu?wbO{@{M6o}0v_@(4cL z@HxZt+sr4bZ%>zUBKN1j|A}*DwKw>+wY7S_BofC*YJdFrammcgTtRGRKf>c*a`X)` zn|mA%deCU?|6#5XeWwb~4fsey^xgF;cHY42-qv_S@mEYeL*TDHp}L2ubvZG)@bA>S zJATHKYWVhQbaMnDZ`~OTc&tcvvTfwk{ zJ@3X|dS}i_&zADOIKuiOV%Y9&>~4b9)z#A2*l48=pDG992~i&^#P>4s_AbTMG1#-@ z?N8!9!MbAdwo`p3BqX?9dwtjMIKHLyUGfgT9ONU4_gG+bbhHotcMxL^JjTGBqSjsD zI?$hFpJiZwo;*FJV9WUq(_YYv9EQOok$mJK-^1kqdC)WJ3nU3iL3ft93)sc*?zmlZ zMbDe-m-AOsQ)7PD=H}-5vM29TY;Y2Y@eDDi@tg!^9I@}=8_C3_UXq=iy&Tx9)&>Lw z$mPqI^&7C#bF*VTdMJXvMxoWMaf#D1{i~>F_WQDOHVz@+NoR9m) yI140vG+x3-Hc6l + + + + + + {children} + + + + + + + + ) +} diff --git a/infrastructure/application/components.json b/infrastructure/application/components.json new file mode 100644 index 0000000..2578f4a --- /dev/null +++ b/infrastructure/application/components.json @@ -0,0 +1,17 @@ +{ + "$schema": "https://ui.shadcn.com/schema.json", + "style": "default", + "rsc": true, + "tsx": true, + "tailwind": { + "config": "tailwind.config.js", + "css": "app/globals.css", + "baseColor": "slate", + "cssVariables": false, + "prefix": "" + }, + "aliases": { + "components": "@/components", + "utils": "@/lib/utils" + } +} \ No newline at end of file diff --git a/infrastructure/application/components/aceternity-ui/3d-pin.tsx b/infrastructure/application/components/aceternity-ui/3d-pin.tsx new file mode 100644 index 0000000..f78491f --- /dev/null +++ b/infrastructure/application/components/aceternity-ui/3d-pin.tsx @@ -0,0 +1,165 @@ +"use client"; +import React, { useState } from "react"; +import { motion } from "framer-motion"; +import { cn } from "@/lib/cn"; + +export const PinContainer = ({ + children, + title, + href, + className, + containerClassName, +}: { + children: React.ReactNode; + title?: string; + href: string; + className?: string; + containerClassName?: string; +}) => { + const [transform, setTransform] = useState( + "translate(-50%,-50%) rotateX(0deg)" + ); + + const onMouseEnter = () => { + setTransform("translate(-50%,-50%) rotateX(40deg) scale(0.8)"); + }; + const onMouseLeave = () => { + setTransform("translate(-50%,-50%) rotateX(0deg) scale(1)"); + }; + + return ( +
+
+
+
{children}
+
+
+ +
+ ); +}; + +export const PinPerspective = ({ + title, + href, +}: { + title?: string; + href: string; +}) => { + return ( + +
+ + +
+ <> + + + + +
+ + <> + + + + + +
+
+ ); +}; diff --git a/infrastructure/application/components/aceternity-ui/moving-border.tsx b/infrastructure/application/components/aceternity-ui/moving-border.tsx new file mode 100644 index 0000000..6823d3c --- /dev/null +++ b/infrastructure/application/components/aceternity-ui/moving-border.tsx @@ -0,0 +1,139 @@ +"use client"; +import React from "react"; +import { + motion, + useAnimationFrame, + useMotionTemplate, + useMotionValue, + useTransform, +} from "framer-motion"; +import { useRef } from "react"; +import { cn } from "@/lib/cn"; + +export function Button({ + borderRadius = "1.75rem", + children, + as: Component = "button", + containerClassName, + borderClassName, + duration, + className, + ...otherProps +}: { + borderRadius?: string; + children: React.ReactNode; + as?: any; + containerClassName?: string; + borderClassName?: string; + duration?: number; + className?: string; + [key: string]: any; +}) { + return ( + +
+ +
+ +
+ +
+ {children} +
+ + ); +} + +export const MovingBorder = ({ + children, + duration = 2000, + rx, + ry, + ...otherProps +}: { + children: React.ReactNode; + duration?: number; + rx?: string; + ry?: string; + [key: string]: any; +}) => { + const pathRef = useRef(); + const progress = useMotionValue(0); + + useAnimationFrame((time) => { + const length = pathRef.current?.getTotalLength(); + if (length) { + const pxPerMillisecond = length / duration; + progress.set((time * pxPerMillisecond) % length); + } + }); + + const x = useTransform( + progress, + (val) => pathRef.current?.getPointAtLength(val).x + ); + const y = useTransform( + progress, + (val) => pathRef.current?.getPointAtLength(val).y + ); + + const transform = useMotionTemplate`translateX(${x}px) translateY(${y}px) translateX(-50%) translateY(-50%)`; + + return ( + <> + + + + + {children} + + + ); +}; diff --git a/infrastructure/application/components/app/hero.tsx b/infrastructure/application/components/app/hero.tsx new file mode 100644 index 0000000..2921219 --- /dev/null +++ b/infrastructure/application/components/app/hero.tsx @@ -0,0 +1,53 @@ +import { ArrowRight, ChevronRight } from 'lucide-react'; +import React from 'react' +import { Button } from '../shadcn-ui/button' +import { useRouter } from 'next/navigation'; +import Link from 'next/link'; +import Image from 'next/image'; +const Hero = () => { + + return ( +
+

Make easy money with financial investment

+

in one DynamiteTrade

+
+

Browse all stocks, keep track of any stock changes

+

and get immediate financial advice from DT advisors

+
+
+ + + + + +

Sign in as Admin

+ + + +
+
+
+
+ landing page animation +
+ +
+ ) +} + +export default Hero \ No newline at end of file diff --git a/infrastructure/application/components/app/home-navbar.tsx b/infrastructure/application/components/app/home-navbar.tsx new file mode 100644 index 0000000..ad3b16c --- /dev/null +++ b/infrastructure/application/components/app/home-navbar.tsx @@ -0,0 +1,43 @@ +"use client" + +import Image from 'next/image' +import Link from 'next/link' +import { useTheme } from 'next-themes' +import { Button } from '../shadcn-ui/button' +const HomeNavbar = () => { + const { setTheme } = useTheme() + setTheme("light") + return ( +
+
+ Logo +

+ DynamiteTrade. +

+
+ + +
+
Home
+
Pricing
+
About Us
+ + + +
+
+ ) +} + +export default HomeNavbar \ No newline at end of file diff --git a/infrastructure/application/components/app/image-component.tsx b/infrastructure/application/components/app/image-component.tsx new file mode 100644 index 0000000..d365e45 --- /dev/null +++ b/infrastructure/application/components/app/image-component.tsx @@ -0,0 +1,13 @@ +import Image from "next/image" + + +export const ImageComponent = () => { + return ( + tsla product + ) +} \ No newline at end of file diff --git a/infrastructure/application/components/app/image-fallback.tsx b/infrastructure/application/components/app/image-fallback.tsx new file mode 100644 index 0000000..affaf8e --- /dev/null +++ b/infrastructure/application/components/app/image-fallback.tsx @@ -0,0 +1,30 @@ +import React, { useState } from 'react'; +import Image from 'next/image'; + +interface ImageWithFallbackProps { + props: { + src: string, + fallbackSrc: string, + width: number, + height: number, + className: string + } +} +const ImageWithFallback = ({ props }: ImageWithFallbackProps) => { + const { src, fallbackSrc, width, height } = props; + const [imgSrc, setImgSrc] = useState(src); + + return ( + product img { + setImgSrc(fallbackSrc); + }} + /> + ); +}; + +export default ImageWithFallback; \ No newline at end of file diff --git a/infrastructure/application/components/app/navbar.tsx b/infrastructure/application/components/app/navbar.tsx new file mode 100644 index 0000000..19f9da7 --- /dev/null +++ b/infrastructure/application/components/app/navbar.tsx @@ -0,0 +1,38 @@ +"use client" + +import { UserButton } from "@clerk/nextjs" +import { usePathname } from "next/navigation" +import { useEffect, useState } from "react" +import { ThemeSwitch } from "./theme-switch" + +const getName = (pathName: string): string => { + if (pathName.includes("dashboard")) return "Dashboard" + else if (pathName.includes("market")) return "Market" + else return "Messenger" +} + +export function Navbar() { + + const [isClient, setIsClient] = useState(false) + + useEffect(() => { + setIsClient(true) + }, []) + if (!isClient) { + return; + } + const pathName = usePathname(); + const name = getName(pathName) + return ( +
+

+ {name} +

+ +
+ + +
+
+ ) +} diff --git a/infrastructure/application/components/app/sidebar.tsx b/infrastructure/application/components/app/sidebar.tsx new file mode 100644 index 0000000..d9eb50d --- /dev/null +++ b/infrastructure/application/components/app/sidebar.tsx @@ -0,0 +1,72 @@ +"use client" + +import { useAnimation } from "@/hooks/use-animation"; +import { cn } from "@/lib/utils"; +import { LayoutDashboard } from "lucide-react"; +import { useTheme } from "next-themes"; +import Image from "next/image"; +import Link from "next/link"; +import { usePathname, useRouter } from "next/navigation"; +import { CiSettings } from "react-icons/ci"; +import { IoStorefrontOutline } from "react-icons/io5"; +import { TbMessageBolt } from "react-icons/tb"; + +const Sidebar = () => { + const { theme } = useTheme() + const pathName = usePathname() + const { animatedId, setAnimatedId } = useAnimation() + const sidebarItems = [ + { + name: "Dashboard", + href: "/dashboard", + icon: LayoutDashboard, + isActive: pathName === "/dashboard" + }, + { + name: "Market", + href: "/market", + icon: IoStorefrontOutline, + isActive: pathName === "/market" + }, + { + name: "Message", + href: `/chat`, + icon: TbMessageBolt, + isActive: pathName === "/chat" + }, + + ] + + const src = theme === 'light' ? "/landing-page/logo.webp" : "/landing-page/logo2.png" + + + return ( +
+ company logo + +
+ {sidebarItems.map((item) => ( + setAnimatedId(2)} + className={cn("py-2 flex items-center justify-center p-3 group rounded-lg text-muted-foreground hover:cursor-pointer hover:text-white hover:bg-[#230f61] shadow-md dark:shadow-sm dark:shadow-purple-700", + item.isActive ? "bg-[#6149cd] text-white shadow-sm shadow-[#19033e] " : "hover:bg-[#6149cd]/30" + )} + > + + + ))} +
+ + +
+ ) +} +export default Sidebar; diff --git a/infrastructure/application/components/app/theme-provider.tsx b/infrastructure/application/components/app/theme-provider.tsx new file mode 100644 index 0000000..8c90fbc --- /dev/null +++ b/infrastructure/application/components/app/theme-provider.tsx @@ -0,0 +1,9 @@ +"use client" + +import * as React from "react" +import { ThemeProvider as NextThemesProvider } from "next-themes" +import { type ThemeProviderProps } from "next-themes/dist/types" + +export function ThemeProvider({ children, ...props }: ThemeProviderProps) { + return {children} +} diff --git a/infrastructure/application/components/app/theme-switch.tsx b/infrastructure/application/components/app/theme-switch.tsx new file mode 100644 index 0000000..24369da --- /dev/null +++ b/infrastructure/application/components/app/theme-switch.tsx @@ -0,0 +1,44 @@ +"use client" + +import * as React from "react" +import { Moon, Sun } from "lucide-react" +import { useTheme } from "next-themes" + +import { Button } from "@/components/shadcn-ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/shadcn-ui/dropdown-menu" + +export function ThemeSwitch() { + const { setTheme } = useTheme() + + return ( + + + + + + setTheme("light")}> + Light + + setTheme("dark")}> + Dark + + setTheme("system")}> + System + + + + ) +} diff --git a/infrastructure/application/components/providers/providers.tsx b/infrastructure/application/components/providers/providers.tsx new file mode 100644 index 0000000..cf42290 --- /dev/null +++ b/infrastructure/application/components/providers/providers.tsx @@ -0,0 +1,15 @@ +"use client" + +import { QueryClient, QueryClientProvider } from "@tanstack/react-query" + +const client = new QueryClient() + +const Providers = ({ children }: { children: React.ReactNode }) => { + return ( + + {children} + + ) +} + +export default Providers \ No newline at end of file diff --git a/infrastructure/application/components/providers/toast-provider.tsx b/infrastructure/application/components/providers/toast-provider.tsx new file mode 100644 index 0000000..8fe8b89 --- /dev/null +++ b/infrastructure/application/components/providers/toast-provider.tsx @@ -0,0 +1,5 @@ +import { Toaster } from "react-hot-toast" + +export const ToastProvider = () => { + return +} \ No newline at end of file diff --git a/infrastructure/application/components/shadcn-ui/accordion.tsx b/infrastructure/application/components/shadcn-ui/accordion.tsx new file mode 100644 index 0000000..24c788c --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/accordion.tsx @@ -0,0 +1,58 @@ +"use client" + +import * as React from "react" +import * as AccordionPrimitive from "@radix-ui/react-accordion" +import { ChevronDown } from "lucide-react" + +import { cn } from "@/lib/utils" + +const Accordion = AccordionPrimitive.Root + +const AccordionItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AccordionItem.displayName = "AccordionItem" + +const AccordionTrigger = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + svg]:rotate-180", + className + )} + {...props} + > + {children} + + + +)) +AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName + +const AccordionContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + +
{children}
+
+)) + +AccordionContent.displayName = AccordionPrimitive.Content.displayName + +export { Accordion, AccordionItem, AccordionTrigger, AccordionContent } diff --git a/infrastructure/application/components/shadcn-ui/button.tsx b/infrastructure/application/components/shadcn-ui/button.tsx new file mode 100644 index 0000000..6f20b1e --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/button.tsx @@ -0,0 +1,58 @@ +import * as React from "react" +import { Slot } from "@radix-ui/react-slot" +import { cva, type VariantProps } from "class-variance-authority" + +import { cn } from "@/lib/utils" + +const buttonVariants = cva( + "inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50", + { + variants: { + variant: { + default: "bg-primary text-primary-foreground hover:bg-primary/90", + destructive: + "bg-destructive text-destructive-foreground hover:bg-destructive/90", + outline: + "border border-input bg-background hover:bg-accent hover:text-accent-foreground", + secondary: + "bg-secondary text-secondary-foreground hover:bg-secondary/80", + ghost: "hover:bg-accent hover:text-accent-foreground", + link: "text-primary underline-offset-4 hover:underline", + purple: "bg-[#500480] text-primary-foreground hover:bg-primary/90", + custom: "bg-none text-muted-foreground" + }, + size: { + default: "h-10 px-4 py-2", + sm: "h-9 rounded-md px-3", + lg: "h-11 rounded-md px-8", + icon: "h-10 w-10", + }, + }, + defaultVariants: { + variant: "default", + size: "default", + }, + } +) + +export interface ButtonProps + extends React.ButtonHTMLAttributes, + VariantProps { + asChild?: boolean +} + +const Button = React.forwardRef( + ({ className, variant, size, asChild = false, ...props }, ref) => { + const Comp = asChild ? Slot : "button" + return ( + + ) + } +) +Button.displayName = "Button" + +export { Button, buttonVariants } diff --git a/infrastructure/application/components/shadcn-ui/card.tsx b/infrastructure/application/components/shadcn-ui/card.tsx new file mode 100644 index 0000000..afa13ec --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/card.tsx @@ -0,0 +1,79 @@ +import * as React from "react" + +import { cn } from "@/lib/utils" + +const Card = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +Card.displayName = "Card" + +const CardHeader = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +CardHeader.displayName = "CardHeader" + +const CardTitle = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardTitle.displayName = "CardTitle" + +const CardDescription = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardDescription.displayName = "CardDescription" + +const CardContent = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardContent.displayName = "CardContent" + +const CardFooter = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +CardFooter.displayName = "CardFooter" + +export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent } diff --git a/infrastructure/application/components/shadcn-ui/checkbox.tsx b/infrastructure/application/components/shadcn-ui/checkbox.tsx new file mode 100644 index 0000000..d1ca508 --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/checkbox.tsx @@ -0,0 +1,30 @@ +"use client" + +import * as React from "react" +import * as CheckboxPrimitive from "@radix-ui/react-checkbox" +import { Check } from "lucide-react" + +import { cn } from "@/lib/utils" + +const Checkbox = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + + +)) +Checkbox.displayName = CheckboxPrimitive.Root.displayName + +export { Checkbox } diff --git a/infrastructure/application/components/shadcn-ui/collapsible.tsx b/infrastructure/application/components/shadcn-ui/collapsible.tsx new file mode 100644 index 0000000..9fa4894 --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/collapsible.tsx @@ -0,0 +1,11 @@ +"use client" + +import * as CollapsiblePrimitive from "@radix-ui/react-collapsible" + +const Collapsible = CollapsiblePrimitive.Root + +const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger + +const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent + +export { Collapsible, CollapsibleTrigger, CollapsibleContent } diff --git a/infrastructure/application/components/shadcn-ui/dialog.tsx b/infrastructure/application/components/shadcn-ui/dialog.tsx new file mode 100644 index 0000000..cf04322 --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/dialog.tsx @@ -0,0 +1,122 @@ +"use client" + +import * as React from "react" +import * as DialogPrimitive from "@radix-ui/react-dialog" +import { X } from "lucide-react" + +import { cn } from "@/lib/utils" + +const Dialog = DialogPrimitive.Root + +const DialogTrigger = DialogPrimitive.Trigger + +const DialogPortal = DialogPrimitive.Portal + +const DialogClose = DialogPrimitive.Close + +const DialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DialogOverlay.displayName = DialogPrimitive.Overlay.displayName + +const DialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + {children} + + + Close + + + +)) +DialogContent.displayName = DialogPrimitive.Content.displayName + +const DialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +DialogHeader.displayName = "DialogHeader" + +const DialogFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +DialogFooter.displayName = "DialogFooter" + +const DialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DialogTitle.displayName = DialogPrimitive.Title.displayName + +const DialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DialogDescription.displayName = DialogPrimitive.Description.displayName + +export { + Dialog, + DialogPortal, + DialogOverlay, + DialogClose, + DialogTrigger, + DialogContent, + DialogHeader, + DialogFooter, + DialogTitle, + DialogDescription, +} diff --git a/infrastructure/application/components/shadcn-ui/dropdown-menu.tsx b/infrastructure/application/components/shadcn-ui/dropdown-menu.tsx new file mode 100644 index 0000000..f69a0d6 --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/dropdown-menu.tsx @@ -0,0 +1,200 @@ +"use client" + +import * as React from "react" +import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu" +import { Check, ChevronRight, Circle } from "lucide-react" + +import { cn } from "@/lib/utils" + +const DropdownMenu = DropdownMenuPrimitive.Root + +const DropdownMenuTrigger = DropdownMenuPrimitive.Trigger + +const DropdownMenuGroup = DropdownMenuPrimitive.Group + +const DropdownMenuPortal = DropdownMenuPrimitive.Portal + +const DropdownMenuSub = DropdownMenuPrimitive.Sub + +const DropdownMenuRadioGroup = DropdownMenuPrimitive.RadioGroup + +const DropdownMenuSubTrigger = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, children, ...props }, ref) => ( + + {children} + + +)) +DropdownMenuSubTrigger.displayName = + DropdownMenuPrimitive.SubTrigger.displayName + +const DropdownMenuSubContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DropdownMenuSubContent.displayName = + DropdownMenuPrimitive.SubContent.displayName + +const DropdownMenuContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, sideOffset = 4, ...props }, ref) => ( + + + +)) +DropdownMenuContent.displayName = DropdownMenuPrimitive.Content.displayName + +const DropdownMenuItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +DropdownMenuItem.displayName = DropdownMenuPrimitive.Item.displayName + +const DropdownMenuCheckboxItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, checked, ...props }, ref) => ( + + + + + + + {children} + +)) +DropdownMenuCheckboxItem.displayName = + DropdownMenuPrimitive.CheckboxItem.displayName + +const DropdownMenuRadioItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + + + + {children} + +)) +DropdownMenuRadioItem.displayName = DropdownMenuPrimitive.RadioItem.displayName + +const DropdownMenuLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +DropdownMenuLabel.displayName = DropdownMenuPrimitive.Label.displayName + +const DropdownMenuSeparator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DropdownMenuSeparator.displayName = DropdownMenuPrimitive.Separator.displayName + +const DropdownMenuShortcut = ({ + className, + ...props +}: React.HTMLAttributes) => { + return ( + + ) +} +DropdownMenuShortcut.displayName = "DropdownMenuShortcut" + +export { + DropdownMenu, + DropdownMenuTrigger, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuCheckboxItem, + DropdownMenuRadioItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuShortcut, + DropdownMenuGroup, + DropdownMenuPortal, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuRadioGroup, +} diff --git a/infrastructure/application/components/shadcn-ui/form.tsx b/infrastructure/application/components/shadcn-ui/form.tsx new file mode 100644 index 0000000..7420d03 --- /dev/null +++ b/infrastructure/application/components/shadcn-ui/form.tsx @@ -0,0 +1,176 @@ +import * as React from "react" +import * as LabelPrimitive from "@radix-ui/react-label" +import { Slot } from "@radix-ui/react-slot" +import { + Controller, + ControllerProps, + FieldPath, + FieldValues, + FormProvider, + useFormContext, +} from "react-hook-form" + +import { cn } from "@/lib/utils" +import { Label } from "@/components/shadcn-ui/label" + +const Form = FormProvider + +type FormFieldContextValue< + TFieldValues extends FieldValues = FieldValues, + TName extends FieldPath = FieldPath +> = { + name: TName +} + +const FormFieldContext = React.createContext( + {} as FormFieldContextValue +) + +const FormField = < + TFieldValues extends FieldValues = FieldValues, + TName extends FieldPath = FieldPath +>({ + ...props +}: ControllerProps) => { + return ( + + + + ) +} + +const useFormField = () => { + const fieldContext = React.useContext(FormFieldContext) + const itemContext = React.useContext(FormItemContext) + const { getFieldState, formState } = useFormContext() + + const fieldState = getFieldState(fieldContext.name, formState) + + if (!fieldContext) { + throw new Error("useFormField should be used within ") + } + + const { id } = itemContext + + return { + id, + name: fieldContext.name, + formItemId: `${id}-form-item`, + formDescriptionId: `${id}-form-item-description`, + formMessageId: `${id}-form-item-message`, + ...fieldState, + } +} + +type FormItemContextValue = { + id: string +} + +const FormItemContext = React.createContext( + {} as FormItemContextValue +) + +const FormItem = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => { + const id = React.useId() + + return ( + +
+ + ) +}) +FormItem.displayName = "FormItem" + +const FormLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => { + const { error, formItemId } = useFormField() + + return ( +