[儲からない競馬予想AI] Section 03-01 : データ整形④ 多出力で学習データの作成

Section01-04で作った学習データを、多出力データへと変形します。

ベースとなるプログラムはSection01-04で書いたコードです。

レースデータを1つの行データとする

変更するのは、Section01-04で書いたコードのrace_data_to_df関数です。

この関数を変更し、新たにrace_data_to_dict関数を作ります。
行っている処理はほとんど前と同じですが、一つのDataFrameを作るのではなく、一つの辞書を作るようにします。また、keyとして'horse_{data["horse_number"]}_'を元のkeyの先頭に付け足します。この辞書を全レース分集めて、DataFrameに直します。

def race_data_to_dict(race_data, race_id, place_id):
    """Flatten one race into a single dict (one row of the final DataFrame).

    Race-level values are stored under plain keys. Every per-horse value is
    stored under ``horse_{horse_number}_{key}`` so that all horses of a race
    end up side by side in the same row.

    Args:
        race_data: dict providing the entries 'race_place_data',
            'race_horses_data' (one dict per horse) and 'race_payout_data'.
        race_id: value stored under 'race_id'.
        place_id: value stored under 'place_id'.

    Returns:
        dict holding the race-level values plus the prefixed per-horse values.

    Raises:
        KeyError: if a horse, jockey or pedigree name is missing from the
            corresponding name-to-id dictionary, or an expected key is absent.
    """
    # The name -> id lookup tables are read from disk on every call.
    with open('formatted_source_data/horse_name2id_dict.pkl', 'rb') as f:
        horse_name_to_id_dict = pickle.load(f)
    with open('formatted_source_data/jocky_name2id_dict.pkl', 'rb') as f:
        jocky_name_to_id_dict = pickle.load(f)
    with open('formatted_source_data/ped_name2id_dict.pkl', 'rb') as f:
        ped_name_to_id_dict = pickle.load(f)

    race_info = race_data['race_place_data']
    horses_data = race_data['race_horses_data']
    race_payout_data = race_data['race_payout_data']

    # 1st step: race-level values (horses without a prize are skipped in the sum)
    race_total_prize = sum(
        horse['prize'] for horse in horses_data if horse['prize'] is not None
    )
    ret = {
        'race_date': race_info['race_date'],
        'race_id': race_id,
        'place_id': place_id,
        'race_grade': race_info['race_grade'],
        'race_distance': race_info['race_distance'],
        'race_type': race_type_pre_process(race_info['race_type']),
        'race_total_prize': race_total_prize,
        'weather': weather_pre_process(race_info['weather']),
        'race_condition': race_condition_pre_process(race_info['race_condition']),
        'horse_count': len(horses_data),
    }

    # Keys copied unchanged from each horse record, in output order.
    copied_keys = (
        'waku', 'horse_number', 'name', 'sex', 'age', 'jocky_weight',
        'jocky_name', 'odds', 'popular', 'weight', 'weight_sub', 'rank',
        'time', 'prize',
    )

    for horse_data_dict in horses_data:
        data = {key: horse_data_dict[key] for key in copied_keys}
        data.update(
            tansyo_hit=0,
            tansyo_payout=0,
            hukusyo_hit=0,
            hukusyo_payout=0,
        )

        # Payout strings are converted to numbers and divided by 100
        # (presumably payout per 100 yen -> return multiplier; TODO confirm).
        if int(race_payout_data['tansyo_ret']) == data['horse_number']:
            data['tansyo_hit'] = 1
            data['tansyo_payout'] = string_to_number(race_payout_data['tansyo_payout']) / 100

        for j, hit_horse_number in enumerate(race_payout_data['hukusyo_ret']):
            if int(hit_horse_number) == data['horse_number']:
                data['hukusyo_hit'] = 1
                data['hukusyo_payout'] = string_to_number(race_payout_data['hukusyo_payout'][j]) / 100

        # add the horse's information from before this race date
        horse_id = horse_name_to_id_dict[data['name']]
        data.update(get_horse_pre_info(horse_id, ret['race_date']))

        # add the jockey's information from before this race date
        jocky_id = jocky_name_to_id_dict[data['jocky_name']]
        data.update(get_jocky_pre_info(jocky_id, ret['race_date']))

        # add pedigree names, encoded as ids
        for i, ped_name in enumerate(horse_data_dict['ped_data']):
            data[f'ped_{i}'] = ped_name_to_id_dict[ped_name]

        # merge into the race row under a per-horse key prefix
        prefix = f'horse_{data["horse_number"]}_'
        for key, value in data.items():
            ret[prefix + key] = value

    return ret

main

mainの処理はほぼ前と変化していません。保存する名前を'formatted_source_data/analysis_data03.pkl'にすることを忘れないでください。

if __name__ == '__main__':
    # Build the multi-output learning data: load every race record, turn each
    # race into one flat dict in parallel, and save the result as a DataFrame.
    print('load database')
    database = []
    race_id = 0
    for place_id in range(1, 11):
        for year in range(25):
            database_path = f'keiba_source_data/20{year:02}_{place_id:02}.data'
            # NOTE: pickle.load can execute arbitrary code; only load data
            # files produced by this project's own scripts.
            with open(database_path, 'rb') as f:
                races = pickle.load(f)
            for data in races:
                # skip empty records and races with fewer than five horses
                if data and len(data['race_horses_data']) >= 5:
                    database.append((data, race_id, place_id))
                    race_id += 1
    if not database:
        print('database load error : size 0')
        # exit with a non-zero status so callers can detect the failure
        raise SystemExit(1)
    print(f'database num : {len(database)}')
    print('race_data to learning_data')
    rets = joblib.Parallel(n_jobs=10, verbose=1)(
        [joblib.delayed(race_data_to_dict)(*arg) for arg in database]
    )
    # one dict per race -> one DataFrame row per race, ordered by date
    df = pd.DataFrame(rets)
    df = df.sort_values('race_date').reset_index(drop=True)
    save_name = 'formatted_source_data/analysis_data03.pkl'
    df.to_pickle(save_name)
タイトルとURLをコピーしました