[機械学習・進化計算による株式取引最適化] No.03-02 DQNによる学習

このプログラムの目的

このプログラムはDQN(Deep Q-Learning)と呼ばれる強化学習によって,前項で作成したシミュレーション環境を学習することにあります.

work_share
├02_get_stock_price
└03_dqn_learning
  ├Dockerfile
  ├docker-compose.yml
  └src
    ├draw_graph
    |  └draw_tools.py
    ├enviroment
    |  └stock_env.py
    ├reinforcement_learning
    |  └dqn.py  (これを作成)
    ├result(自動生成)
    └experiment01.py

使用ライブラリ

import pandas as pd

import json
import os

import tqdm
import environment.stock_env as rl_env
import draw_graph.draw_tools as draw_tools

import torch

from pfrl import explorers, q_functions
import pfrl

学習過程

基本的にはdfにデータセット,argsにパラメータ類を入れて渡す形式をとっています.

  • Qネットーワークのサイズは256ユニット×4層にしました.本当は要調整です.
  • リプレイメモリのバッファサイズはenv.max_data_length*1000としており,凡そ1000件の銘柄が保存できるようにしてあります.
  • エピソードを回す際にargs['resample_interval']回毎に銘柄コードを変更するようにしています.
  • 初期金がargs['money_lower_limit_rate']以下になったらリセットがかかるようにしています.
  • env.evaluate_codesにはその銘柄の良さを保存するために,利益率のエピソード合計と学習したエピソード数を記録します.
  • また1000回エピソードを学習するごとに現在の記録を保存するようにしています.
def learning(df, args):
    env = rl_env.stock_env(
        init_money=args['init_money'],
        trade_cost_func=args['trade_cost_func'],
        purchase_max=args['purchase_max'],
        reward_last_only=args['reward_last_only']
    )

    env.set_df(df)
    env.resample()
    print(f'input num : {env.input_num}, max data length : {env.max_data_length}')
    q_func = q_functions.FCStateQFunctionWithDiscreteAction(
        env.input_num,
        env.action_num,
        n_hidden_channels=256,
        n_hidden_layers=4,
    )
    optimizer = torch.optim.AdamW(q_func.parameters())
    gamma = args['gamma']

    explorer = pfrl.explorers.ConstantEpsilonGreedy(
        epsilon=args['epsilon'], random_action_func=env.action_space.sample)

    replay_buffer = pfrl.replay_buffers.ReplayBuffer(capacity=env.max_data_length*1000)

    gpu = 0
    agent = pfrl.agents.DoubleDQN(
        q_func,
        optimizer,
        replay_buffer,
        gamma,
        explorer,
        replay_start_size=args['replay_start_size'],
        update_interval=args['update_interval'],
        target_update_interval=args['target_update_interval'],
        gpu=gpu,
        minibatch_size=args['minibatch_size']
    )
    n_episodes = args['n_episodes']


    result = {
        'training_log':[],
        'last_tuning_log':[],
        'train_eval_ret':[],
        'test_eval_ret':[],
        'use_codes':[]
    }

    for i in range(1, n_episodes + 1):
        if (i % args['resample_interval']) ==0:
            env.resample()
            env.evaluate_codes[env.selected_code] = (0, 0)
        obs = env.reset()
        R = 0  # sum of rewards
        t = 0  # time step
        for t in range(len(env.data['X'])):
            action = agent.act(obs)
            obs, reward, done, info = env.step(action)
            R += reward
            t += 1
            reset = env.total_assets < env.init_money*args['money_lower_limit_rate']
            agent.observe(obs, reward, done, reset)
            if done or reset:
                break

        return_rate = env.total_assets/env.init_money
        sum_rate, count = env.evaluate_codes[env.selected_code]
        env.evaluate_codes[env.selected_code] = (sum_rate+return_rate, count+1)

        print(f'episode: {i}, R: {R}, last money : {env.money}, total assets : {env.total_assets} ({return_rate}%)')
        statistics = agent.get_statistics()
        print('statistics:', statistics)
        result['training_log'].append({
            'i':i,
            'R':R,
            'total_assets':env.total_assets,
        })
        for key, value in statistics:
            result['training_log'][-1][key] = value

        if (i % 1000)==0:
            #(save_temp_result use some codes. not all codes)
            save_temp_result(i, args, agent, env, result)
            # use all codes and use train_data
            env.resample()
    print('Finished.')

記録の保存

記録の保存では次のステップで複数の銘柄を取引した際のシミュレーション記録を取ります.

  1. env.set_eval_use_codes(borderline=1.01)にて,平均利益率が101%以上の記録を出せた銘柄を選びます.
  2. 選ばれた銘柄のみで実際にシミュレーションを回してみて,破滅的忘却(catastrophic forgetting)が起きていないことを確認します(final_eval_use_codes).忘却している銘柄は外します.
  3. 学習データについて複数銘柄の取引をした際の記録を行います(env.eval_mode('train')).
  4. テストデータについて複数銘柄の取引をした際の記録を行います(env.eval_mode('test')).
  5. 各記録について保存します.
def save_temp_result(i, args, agent, env, result):
    temp_result = {
        'train_eval_ret':[],
        'test_eval_ret':[],
        'use_codes':[],
        'final_use_codes':[]
    }
    save_dir = f'{args["result_dir"]}/temp_result/{i}'
    os.makedirs(save_dir, exist_ok=True)
    print(f'{i} save : temp result' )
    env.set_eval_use_codes(borderline=1.01)
    print(f'use codes : {env.eval_use_codes}')
    temp_result['use_codes'] = [int(code) for code in env.eval_use_codes]

    with open(f'{save_dir}/eval_use_codes.json', 'w') as f:
        json.dump(temp_result['use_codes'], f, indent=4, ensure_ascii=False)

    agent.save(f'{save_dir}/model')

    with agent.eval_mode():
        print('determine the final eval use codes')
        final_eval_use_codes = []
        for code in tqdm.tqdm(env.eval_use_codes):
            env.resample(code = code)
            obs = env.reset()
            for t in range(len(env.data['X'])):
                action = agent.act(obs)
                obs, reward, done, info = env.step(action)
                reset = env.total_assets < env.init_money*args['money_lower_limit_rate']
                agent.observe(obs, reward, done, reset)
                if done or reset:
                    break
            return_rate = env.total_assets/env.init_money
            if return_rate > 1.01:
                final_eval_use_codes.append(code)
        env.eval_use_codes = final_eval_use_codes
        print(f'final use codes : {env.eval_use_codes}')
        temp_result['final_use_codes'] = [int(code) for code in env.eval_use_codes]
        with open(f'{save_dir}/final_eval_use_codes.json', 'w') as f:
            json.dump(temp_result['final_use_codes'], f, indent=4, ensure_ascii=False)
        if len(env.eval_use_codes) > 0:
            print('eval train')
            env.eval_mode('train')
            obs = env.reset()
            R = 0
            for t in tqdm.tqdm(range(env.max_data_length)):
                action = agent.act(obs)
                obs, r, done, info = env.step(action)
                R += r
                reset = done
                #agent.observe(obs, r, done, reset)
                temp_result['train_eval_ret'].append({
                    'date':info['date'],
                    'code':info['code'],
                    'money':env.money,
                    'total_assets':env.total_assets,
                    'all_sold_money':env.all_sold_money
                })
            print(f'train  evaluation episode: {i}, R: {R}, last money : {env.money}, total assets : {env.total_assets} ({100*env.total_assets/env.init_money}%)')

            print('eval test')
            env.eval_mode('test')
            obs = env.reset()
            R = 0
            for t in tqdm.tqdm(range(env.max_data_length)):
                action = agent.act(obs)
                obs, r, done, info = env.step(action)
                R += r
                reset = done
                #agent.observe(obs, r, done, reset)
                temp_result['test_eval_ret'].append({
                    'date':info['date'],
                    'code':info['code'],
                    'money':env.money,
                    'total_assets':env.total_assets,
                    'all_sold_money':env.all_sold_money
                })
            print(f'test  evaluation episode: {i}, R: {R}, last money : {env.money}, total assets : {env.total_assets} ({100*env.total_assets/env.init_money}%)')

            save_path = f'{save_dir}/eval_assets_train.jpg'
            draw_tools.plot_total_assets_graph(pd.DataFrame(temp_result['train_eval_ret']), save_path, 'eval train assets')

            save_path = f'{save_dir}/eval_assets_test.jpg'
            draw_tools.plot_total_assets_graph(pd.DataFrame(temp_result['test_eval_ret']), save_path, 'eval test assets')

            save_path = f'{save_dir}/eval_money_train.jpg'
            draw_tools.plot_money_graph(pd.DataFrame(temp_result['train_eval_ret']), save_path, 'money', 'money')

            save_path = f'{save_dir}/eval_money_test.jpg'
            draw_tools.plot_money_graph(pd.DataFrame(temp_result['test_eval_ret']), save_path, 'money', 'money')

            save_path = f'{save_dir}/eval_all_sold_money_train.jpg'
            draw_tools.plot_money_graph(pd.DataFrame(temp_result['train_eval_ret']), save_path, 'all_sold_money', 'all_sold_money')

            save_path = f'{save_dir}/eval_all_sold_money_test.jpg'
            draw_tools.plot_money_graph(pd.DataFrame(temp_result['test_eval_ret']), save_path, 'all_sold_money', 'all_sold_money')

        for col in ['R', 'total_assets', 'average_q', 'average_loss']:
            save_path = f'{save_dir}/training_log_{col}.jpg'
            draw_tools.plot_log(pd.DataFrame(result['training_log']), save_path, col)
タイトルとURLをコピーしました