PaddlePaddle Learning Competition: League of Legends Master Prediction — 85.365 points (February 2023)
Competition link — PaddlePaddle Learning Competition: League of Legends Master Prediction
Introduction
This is a typical binary classification task set against the League of Legends mobile game: given a player's in-game statistics from a match, contestants must predict whether that player won or lost.
Data Description
Each row in the dataset is one player's record for a single match. The data fields are as follows:
- id: player record id
- win: win or not, label variable
- kills: number of kills
- deaths: number of deaths
- assists: the number of assists
- largestkillingspree: largest killing spree (game term: a run of at least three kills without dying in between)
- largestmultikill: largest multikill (game term: multiple kills within a short time window)
- longesttimespentliving: the longest survival time
- doublekills: the number of doublekills
- triplekills: number of triple kills
- quadrakills: number of quadrakills
- pentakills: number of pentakills
- totdmgdealt: total damage dealt
- magicdmgdealt: magic damage dealt
- physicaldmgdealt: physical damage dealt
- truedmgdealt: true damage dealt
- largestcrit: maximum crit damage
- totdmgtochamp: damage to opposing players
- magicdmgtochamp: magic damage to opposing players
- physdmgtochamp: Physical damage to opposing players
- truedmgtochamp: true damage to opposing players
- totheal: total healing done
- totunitshealed: Total units healed
- dmgtoturrets: damage to turrets
- timecc: total crowd-control time dealt
- totdmgtaken: damage taken
- magicdmgtaken: magic damage taken
- physdmgtaken: physical damage taken
- truedmgtaken: true damage taken
- wardsplaced: Number of scout wards placed
- wardskilled: number of scout wards destroyed
- firstblood: whether the player got first blood
The win label is left empty in the test set; contestants must predict it.
Competition Difficulty
The difficulty of this competition lies mainly in data processing and model selection:
- The game data has no missing values and no particularly obvious outliers, and the distributions are fairly well behaved, so feature engineering offers few easy wins in accuracy.
- In fact, even without any data processing, off-the-shelf random forest, LightGBM, and XGBoost models easily reach 82-84 points, while a neural network struggles to do better and can even underperform the tree models. A minimal tree-model baseline is sketched below.
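As a rough sanity check of that claim, a minimal LightGBM baseline could look like the following sketch (lightgbm is not among this project's requirements, and the hyperparameters here are illustrative, not tuned competition settings):

```python
import pandas as pd
from lightgbm import LGBMClassifier
from sklearn.model_selection import cross_val_score

# Raw features straight from the CSV, no feature engineering at all
train = pd.read_csv('data/train.csv')
X, y = train.drop(['win', 'id'], axis=1), train['win']

clf = LGBMClassifier(n_estimators=500, learning_rate=0.05)
# 5-fold cross-validated accuracy; per the text, tree models land around 0.82-0.84
print(cross_val_score(clf, X, y, cv=5, scoring='accuracy').mean())
```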
Project Highlights
- A genetic algorithm (gplearn) is used to construct additional features, and PCA is used for feature dimensionality reduction
- The network is designed around the dense-connection idea (as in DenseNet) to mitigate vanishing gradients
Code Repository
https://github.com/ZhangzrJerry/paddle-lolmp
- Decompress data and install dependencies
```
!unzip -d data/ data/data137276/test.csv.zip
!unzip -d data/ data/data137276/train.csv.zip
!pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
!mkdir result pretrain
```
```
Archive:  data/data137276/test.csv.zip
  inflating: data/test.csv
Archive:  data/data137276/train.csv.zip
  inflating: data/train.csv
(pip install log trimmed)
Successfully installed category_encoders-2.6.0 gplearn-0.4.2 joblib-1.2.0 patsy-0.5.3 pydotplus-2.0.2 scikit-learn-1.0.2 statsmodels-0.13.5
```
```python
import numpy as np
import pandas as pd

x_train = pd.read_csv('data/train.csv').drop('win', axis=1)
y_train = pd.read_csv('data/train.csv')['win']
x_test = pd.read_csv('data/test.csv')
# Concatenate train and test so feature engineering is applied to both consistently
feature = pd.concat([x_train, x_test])
```
- Data visualization
```python
import matplotlib.pyplot as plt
import seaborn as sns

# View data information
print(feature.describe())
print(feature.info())
# View the distribution of each feature
sns.FacetGrid(pd.melt(feature), col="variable", col_wrap=4,
              sharex=False, sharey=False).map(sns.distplot, "value")
# View feature correlations
sns.set_context({"figure.figsize": (8, 8)})
sns.heatmap(data=feature.corr(), square=True, cmap='RdBu_r')
```
```
(describe()/info() output trimmed: 200,000 rows, 31 columns, all non-null int64,
about 48.8 MB in memory; means are e.g. 5.80 kills, 5.81 deaths, and 8.32 assists
per game, and timecc is 0 in every row, so it carries no information)
```
(seaborn raises "UserWarning: Data must have variance to compute a kernel density estimate", which comes from the constant timecc column.)
(Figure: per-feature distribution plots and feature-correlation heatmap)
- Data preprocessing
```python
from category_encoders import *
from sklearn.preprocessing import MinMaxScaler

# Ordinal-encode the discrete variables
columns = ['kills', 'deaths', 'assists', 'largestkillingspree', 'largestmultikill',
           'longesttimespentliving', 'doublekills', 'triplekills', 'quadrakills',
           'pentakills', 'firstblood']
m_feature = OrdinalEncoder(cols=columns).fit(feature).transform(feature)
m_feature.columns = [
    'id', 'akills', 'adeaths', 'aassists', 'alargestkillingspree', 'alargestmultikill',
    'alongesttimespentliving', 'adoublekills', 'atriplekills', 'aquadrakills',
    'apentakills', 'totdmgdealt', 'magicdmgdealt', 'physicaldmgdealt', 'truedmgdealt',
    'largestcrit', 'totdmgtochamp', 'magicdmgtochamp', 'physdmgtochamp',
    'truedmgtochamp', 'totheal', 'totunitshealed', 'dmgtoturrets', 'timecc',
    'totdmgtaken', 'magicdmgtaken', 'physdmgtaken', 'truedmgtaken', 'wardsplaced',
    'wardskilled', 'afirstblood'
]
# Keep both the encoded and the raw discrete columns; drop the constant timecc
# column and the id column
m_feature = pd.concat([m_feature, feature[columns]], axis=1).drop(['timecc', 'id'], axis=1)
# Min-max normalize the features
sc = MinMaxScaler().fit(m_feature)
feature = sc.transform(m_feature)
pd.DataFrame(feature).to_csv('pretrain/feature.csv', index=False)
pd.DataFrame(feature[:180000]).to_csv('pretrain/train_feature.csv', index=False)
pd.DataFrame(feature[180000:]).to_csv('pretrain/test_feature.csv', index=False)
```
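For readers unfamiliar with category_encoders, the OrdinalEncoder above maps each distinct value of a listed column to a small integer code; a toy sketch (the exact codes in the comment are illustrative):

```python
from category_encoders import OrdinalEncoder
import pandas as pd

toy = pd.DataFrame({'kills': [0, 3, 3, 7]})
# Each distinct value gets its own integer code (e.g. 0 -> 1, 3 -> 2, 7 -> 3);
# unlike one-hot encoding, the column stays a single feature
print(OrdinalEncoder(cols=['kills']).fit_transform(toy))
```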
- Automatic feature mining
```python
from gplearn.genetic import SymbolicTransformer
from IPython.display import Image
import pydotplus
import numpy as np
import pandas as pd

function_set = ['add', 'sub', 'mul', 'div', 'log', 'sqrt', 'abs', 'neg', 'max', 'min']
st = SymbolicTransformer(
    generations=20,
    population_size=1000,
    hall_of_fame=100,
    n_components=100,
    function_set=function_set,
    parsimony_coefficient=0.0005,
    max_samples=0.9,
    verbose=1,
    random_state=0,
    n_jobs=3
)
st.fit(np.array(pd.read_csv('pretrain/train_feature.csv')),
       np.array(pd.read_csv('data/train.csv')['win']))

# Render the best evolved program as a graph
graph = st._best_programs[0].export_graphviz()
graph = pydotplus.graphviz.graph_from_dot_data(graph)
display(Image(graph.create_png()))

# Append the 100 evolved features to the original features
pd.DataFrame(
    np.concatenate(
        [st.transform(np.array(pd.read_csv('pretrain/train_feature.csv'))),
         np.array(pd.read_csv('pretrain/train_feature.csv'))],
        axis=1
    )
).to_csv('pretrain/gptrain.csv', index=False)
pd.DataFrame(
    np.concatenate(
        [st.transform(np.array(pd.read_csv('pretrain/test_feature.csv'))),
         np.array(pd.read_csv('pretrain/test_feature.csv'))],
        axis=1
    )
).to_csv('pretrain/gptest.csv', index=False)
```
| Gen | Avg Length | Avg Fitness | Best Length | Best Fitness | Best OOB Fitness | Time Left |
|----:|-----------:|------------:|------------:|-------------:|-----------------:|----------:|
| 0 | 12.03 | 0.0967885 | 4 | 0.395269 | 0.413457 | 2.82m |
| 1 | 6.91 | 0.250165 | 5 | 0.512841 | 0.518938 | 1.67m |
| 2 | 4.82 | 0.322645 | 7 | 0.600911 | 0.603908 | 1.49m |
| 3 | 5.58 | 0.365057 | 12 | 0.604385 | 0.606605 | 1.34m |
| 4 | 6.48 | 0.414963 | 9 | 0.615257 | 0.610385 | 1.35m |
| 5 | 8.58 | 0.468763 | 9 | 0.628536 | 0.620473 | 1.22m |
| 6 | 11.07 | 0.508876 | 15 | 0.638086 | 0.641776 | 1.18m |
| 7 | 13.61 | 0.514324 | 11 | 0.640324 | 0.636264 | 1.12m |
| 8 | 14.01 | 0.540671 | 39 | 0.656748 | 0.657258 | 1.09m |
| 9 | 15.80 | 0.544383 | 39 | 0.656818 | 0.656685 | 1.02m |
| 10 | 18.48 | 0.556065 | 41 | 0.656622 | 0.658435 | 54.90s |
| 11 | 19.58 | 0.569885 | 50 | 0.66438 | 0.666361 | 49.46s |
| 12 | 20.93 | 0.58614 | 39 | 0.665209 | 0.665554 | 46.77s |
| 13 | 22.09 | 0.589301 | 53 | 0.666591 | 0.67299 | 38.94s |
| 14 | 23.42 | 0.597276 | 59 | 0.666831 | 0.671305 | 33.80s |
| 15 | 25.60 | 0.600218 | 65 | 0.66804 | 0.668831 | 28.90s |
| 16 | 27.89 | 0.618124 | 76 | 0.667599 | 0.668118 | 22.03s |
| 17 | 31.54 | 0.6234 | 62 | 0.668819 | 0.660909 | 15.79s |
| 18 | 33.64 | 0.631025 | 55 | 0.669813 | 0.669861 | 7.87s |
| 19 | 32.85 | 0.625123 | 52 | 0.672489 | 0.662052 | 0.00s |
(Figure: graphviz rendering of the best evolved program)
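Besides the rendered graph, the evolved features can be printed directly as formulas. A small sketch (`_best_programs` is a gplearn internal, as already used above; the output in the comment is hypothetical):

```python
# Print the top few evolved programs as symbolic formulas over the input columns
for prog in st._best_programs[:3]:
    print(prog)  # e.g. mul(sub(X3, X12), sqrt(X20)) — actual output will differ
```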
- Feature dimensionality reduction
```python
import pandas as pd
from sklearn.decomposition import PCA

traindata = pd.read_csv('pretrain/gptrain.csv')
testdata = pd.read_csv('pretrain/gptest.csv')
# Project the concatenated (evolved + original) features down to 64 components
transfer = PCA(n_components=64)
pd.DataFrame(transfer.fit_transform(traindata)).to_csv('pretrain/pcatrain.csv', index=False)
pd.DataFrame(transfer.transform(testdata)).to_csv('pretrain/pcatest.csv', index=False)
```
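The choice of 64 components is a hyperparameter; one quick way to sanity-check it is the retained-variance ratio (a sketch, assuming `transfer` has been fitted as above):

```python
# Fraction of total variance kept by the 64 components; a low value would
# suggest raising n_components
print(transfer.explained_variance_ratio_.sum())
```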
- Model building
```python
import paddle
from paddle import nn
import paddle.nn.functional as F

class DenseBlock(nn.Layer):
    def __init__(self):
        super(DenseBlock, self).__init__()
        self.fc1 = nn.Linear(96, 64)
        self.fc2 = nn.Linear(64, 32)

    def forward(self, input, dense):
        x = self.fc1(input)
        x = self.fc2(x)
        x = F.relu(x)
        # Dense connection: concatenate the block output with the original features
        x = paddle.concat([x, dense], axis=1)
        return x

class MyNet(nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc0 = nn.Linear(64, 96)
        self.bl1 = DenseBlock()
        self.bl2 = DenseBlock()
        self.bl3 = DenseBlock()
        self.bl4 = DenseBlock()
        self.bl5 = DenseBlock()
        self.bl6 = DenseBlock()
        self.bl7 = DenseBlock()
        self.bl8 = DenseBlock()
        self.bl9 = DenseBlock()
        self.fc4 = nn.Linear(96, 32)
        self.fc5 = nn.Linear(32, 1)

    def forward(self, input):
        x = self.fc0(input)
        # Each block also receives the raw 64-dim input as its dense connection
        x = self.bl1(x, input)
        x = self.bl2(x, input)
        x = self.bl3(x, input)
        x = self.bl4(x, input)
        x = self.bl5(x, input)
        x = self.bl6(x, input)
        x = self.bl7(x, input)
        x = self.bl8(x, input)
        x = self.bl9(x, input)
        x = self.fc4(x)
        x = self.fc5(x)
        return F.sigmoid(x)

paddle.summary(MyNet(), (180000, 64))
```
```
W0216 17:52:57 gpu_resources.cc: device: 0, GPU Compute Capability: 8.0,
Driver API Version: 11.2, Runtime API Version: 11.2, cuDNN Version: 8.2
(layer table trimmed: Linear-1 through Linear-21 across the nine DenseBlocks)
Total params: 83,969
Trainable params: 83,969
Non-trainable params: 0
Input size (MB): 43.95
Forward/backward pass size (MB): 2550.20
Params size (MB): 0.32
Estimated Total Size (MB): 2594.47
```
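To see why every DenseBlock expects a 96-dimensional input (each block's 32 ReLU units are concatenated with the 64 PCA features), a dummy forward pass is an easy check; a small sketch, with an arbitrary batch size of 8:

```python
import numpy as np
import paddle

# Push a dummy mini-batch through the network and confirm the shapes line up
dummy = paddle.to_tensor(np.zeros((8, 64), dtype='float32'))
print(MyNet()(dummy).shape)  # expected: [8, 1] — one win probability per sample
```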
- Model training
```python
import paddle
import paddle.nn.functional as F
import pandas as pd
import numpy as np

def train_pm(model, optimizer, feature, label, epoches=1):
    # Train on GPU 0
    paddle.device.set_device('gpu:0')
    print('start training')
    model.train()
    feature = paddle.to_tensor(feature)
    label = paddle.reshape(paddle.to_tensor(label), (-1, 1))
    for epoch in range(epoches):
        # Forward pass over the full training set (full-batch training)
        logits = model(feature)
        # Note: MyNet already ends in F.sigmoid, so binary_cross_entropy_with_logits
        # applies a second sigmoid here; F.binary_cross_entropy would match the
        # probability outputs directly
        loss = F.binary_cross_entropy_with_logits(logits, label)
        avg_loss = paddle.mean(loss)
        # Backpropagate, update weights, clear gradients
        avg_loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        # Evaluate (on the same training data — there is no held-out split)
        model.eval()
        accuracies = []
        losses = []
        logits = model(feature)
        pred = logits
        loss = F.binary_cross_entropy_with_logits(logits, label)
        # Build a two-column [P(lose), P(win)] matrix so accuracy thresholds at 0.5
        pred2 = pred * (-1.0) + 1.0
        pred = paddle.concat([pred2, pred], axis=1)
        acc = paddle.metric.accuracy(pred, paddle.cast(label, dtype='int64'))
        accuracies.append(acc.numpy())
        losses.append(loss.numpy())
        print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
        model.train()
        # paddle.save(model.state_dict(), 'model/lolmp{}_{}.pdparams'.format(epoch, acc.numpy()))

mynet = MyNet()
train_pm(
    mynet,
    paddle.optimizer.Adam(parameters=mynet.parameters(), learning_rate=0.005),
    np.array(pd.read_csv('pretrain/pcatrain.csv')).astype('float32'),
    np.array(pd.read_csv('data/train.csv')['win']).astype('float32'),
    400
)
```
```
start training
[validation] accuracy/loss: 0.5128/0.6676
[validation] accuracy/loss: 0.8079/0.6187
[validation] accuracy/loss: 0.8114/0.5981
...
(per-epoch logs trimmed: accuracy climbs from roughly 0.81 to 0.86 over 400 epochs)
...
[validation] accuracy/loss: 0.8599/0.5656
[validation] accuracy/loss: 0.8580/0.5659
[validation] accuracy/loss: 0.8548/0.5660
[validation] accuracy/loss: 0.8591/0.5663
```
- Prediction
```python
# Threshold the predicted win probabilities at 0.5 and write the submission file
pd.DataFrame(
    np.where(
        mynet(
            paddle.to_tensor(
                np.array(pd.read_csv('pretrain/pcatest.csv')).astype('float32')
            )
        ).numpy() > 0.5,
        1, 0
    ),
    columns=['win']
).to_csv('submission.csv', index=False)
```
```
!zip result/submission.zip submission.csv
!rm submission.csv
```

```
  adding: submission.csv (deflated 90%)
```
Project Summary
The project still has plenty of room for improvement:
- The network structure could be further optimized with the Network-in-Network idea
- Although all 180,000 samples can be trained in a single full batch, splitting training into more, smaller batches might still improve results
- The project has no explicit train/validation split (the "validation" metric above is computed on the training data itself); cross-validation could be tried to reduce overfitting, as in the sketch below
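For that last point, a minimal sketch of a K-fold setup over the existing PCA features (the fold count and seed are arbitrary choices; `train_pm` and `MyNet` are the functions defined above):

```python
from sklearn.model_selection import KFold
import numpy as np
import pandas as pd

X = np.array(pd.read_csv('pretrain/pcatrain.csv')).astype('float32')
y = np.array(pd.read_csv('data/train.csv')['win']).astype('float32')

# Train one fresh model per fold and evaluate on the held-out part;
# the per-fold validation scores would then be averaged
for fold, (tr_idx, va_idx) in enumerate(
        KFold(n_splits=5, shuffle=True, random_state=0).split(X)):
    x_tr, y_tr = X[tr_idx], y[tr_idx]
    x_va, y_va = X[va_idx], y[va_idx]
    # e.g. train_pm(MyNet(), ..., x_tr, y_tr, 400), then score on (x_va, y_va)
    print(fold, x_tr.shape, x_va.shape)
```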