Skip to content

Commit 1efab86

Browse files
committed
Added model
1 parent 9738c16 commit 1efab86

File tree

8 files changed

+166
-2
lines changed

8 files changed

+166
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ composer.phar
33
/resources/valid.csv
44
/resources/training.csv
55
/resources/test.csv
6+
/resources/model.rbx
67

78
###> symfony/framework-bundle ###
89
/.env.local
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
<?php
2+
3+
namespace App\Application\ML;
4+
5+
use Rubix\ML\Learner;
6+
use Rubix\ML\NeuralNet\ActivationFunctions\ReLU;
7+
use Rubix\ML\NeuralNet\Layers\Activation;
8+
use Rubix\ML\NeuralNet\Layers\Dense;
9+
use Rubix\ML\Pipeline;
10+
use Rubix\ML\Regressors\MLPRegressor;
11+
use Rubix\ML\Transformers\NumericStringConverter;
12+
use Rubix\ML\Transformers\ZScaleStandardizer;
13+
14+
/**
 * Factory for the learner used by this application's ML steps.
 *
 * @see https://docs.rubixml.com/latest/transformers/z-scale-standardizer.html
 * @see https://docs.rubixml.com/latest/regressors/mlp-regressor.html
 */
class LearnerFactory
{
    /**
     * Assembles a new, untrained pipeline: numeric-string conversion followed by
     * z-scale standardization, feeding an MLP regressor whose hidden layers are
     * two Dense(8) layers, each followed by a ReLU activation.
     */
    public static function createLearner(): Learner
    {
        // Transformers are applied to the samples in the order listed here.
        $preprocessing = [
            new NumericStringConverter(),
            new ZScaleStandardizer(),
        ];

        // Hidden-layer stack handed to the regressor.
        $hiddenLayers = [
            new Dense(8),
            new Activation(new ReLU()),
            new Dense(8),
            new Activation(new ReLU()),
        ];

        return new Pipeline($preprocessing, new MLPRegressor($hiddenLayers));
    }
}

src/Application/ML/PowerStationDatasetSplitter.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use League\Csv\Exception;
99
use League\Csv\UnavailableStream;
1010
use League\Csv\Writer;
11-
use Rubix\ML\Datasets\Labeled;
11+
use Rubix\ML\Datasets\Unlabeled;
1212
use Rubix\ML\Extractors\CSV;
1313
use Rubix\ML\Transformers\NumericStringConverter;
1414

@@ -26,7 +26,7 @@ public function __construct(
2626
*/
2727
public function split(): void
2828
{
29-
$dataset = Labeled::fromIterator(new CSV(
29+
$dataset = Unlabeled::fromIterator(new CSV(
3030
$this->appPathResolver->getResourcesPath(FileNames::VALID_DATA),
3131
header: true,
3232
))->apply(new NumericStringConverter());
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
<?php
2+
3+
namespace App\Application\ML;
4+
5+
use App\Application\Path\AppPathResolver;
6+
use App\Domain\FileNames;
7+
use Rubix\ML\CrossValidation\Reports\AggregateReport;
8+
use Rubix\ML\CrossValidation\Reports\ErrorAnalysis;
9+
use Rubix\ML\Datasets\Labeled;
10+
use Rubix\ML\Extractors\CSV;
11+
use Rubix\ML\PersistentModel;
12+
use Rubix\ML\Persisters\Filesystem;
13+
use Rubix\ML\Report;
14+
15+
/**
 * Produces an error-analysis report for the persisted model using the test set file.
 *
 * @see https://docs.rubixml.com/latest/cross-validation.html
 * @see https://docs.rubixml.com/latest/cross-validation/reports/error-analysis.html
 */
readonly class PowerStationReport
{
    public function __construct(
        private AppPathResolver $appPathResolver,
    ) {
    }

    /**
     * Loads the test set (CSV read without a header row) and the saved model,
     * runs predictions over the test samples and compares them with the labels.
     *
     * @return Report an aggregate report containing a single ErrorAnalysis entry
     */
    public function generateReport(): Report
    {
        $testSetPath = $this->appPathResolver->getResourcesPath(FileNames::TEST_SET);

        // Labels are cast to float before being handed to the regression report.
        $testSet = Labeled::fromIterator(new CSV($testSetPath, header: false))
            ->transformLabels(static fn ($label): float => (float) $label);

        $modelPath = $this->appPathResolver->getResourcesPath(FileNames::MODEL_FILENAME);
        $model = PersistentModel::load(new Filesystem($modelPath));

        $predicted = $model->predict($testSet);

        $analysis = new AggregateReport([new ErrorAnalysis()]);

        return $analysis->generate($predicted, $testSet->labels());
    }
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
<?php
2+
3+
namespace App\Application\ML;
4+
5+
use App\Application\Path\AppPathResolver;
6+
use App\Domain\FileNames;
7+
use Rubix\ML\CrossValidation\KFold;
8+
use Rubix\ML\CrossValidation\Metrics\MeanSquaredError;
9+
use Rubix\ML\Datasets\Labeled;
10+
use Rubix\ML\Extractors\CSV;
11+
12+
/**
 * Cross-validates a freshly built learner on the valid-data file.
 *
 * @see https://docs.rubixml.com/latest/cross-validation.html#validators
 * @see https://docs.rubixml.com/latest/cross-validation/metrics/mean-squared-error.html
 */
readonly class PowerStationTester
{
    public function __construct(
        private AppPathResolver $appPathResolver,
    ) {
    }

    /**
     * Runs K-fold validation with the MeanSquaredError metric.
     *
     * @param int $foldsNumber value passed to the KFold validator's constructor
     *
     * @return float the score returned by the validator
     */
    public function test(int $foldsNumber = 5): float
    {
        $validDataPath = $this->appPathResolver->getResourcesPath(FileNames::VALID_DATA);

        // This file is read with a header row; labels are cast to float.
        $samples = Labeled::fromIterator(new CSV($validDataPath, header: true))
            ->transformLabels(static fn ($label): float => (float) $label);

        $learner = LearnerFactory::createLearner();
        $kFold = new KFold($foldsNumber);

        return $kFold->test($learner, $samples, new MeanSquaredError());
    }
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
<?php
2+
3+
namespace App\Application\ML;
4+
5+
use App\Application\Path\AppPathResolver;
6+
use App\Domain\FileNames;
7+
use Rubix\ML\Datasets\Labeled;
8+
use Rubix\ML\Extractors\CSV;
9+
use Rubix\ML\PersistentModel;
10+
use Rubix\ML\Persisters\Filesystem;
11+
12+
/**
 * Trains the learner on the training set file and saves it to the model file.
 */
readonly class PowerStationTrainer
{
    public function __construct(
        private AppPathResolver $appPathResolver,
    ) {
    }

    /**
     * Builds a new learner, fits it on the training set (CSV read without a
     * header row) and persists the result through a Filesystem persister.
     *
     * @param bool $history forwarded as the second argument of the Filesystem persister
     */
    public function train(bool $history = false): void
    {
        $trainingSetPath = $this->appPathResolver->getResourcesPath(FileNames::TRAINING_SET);

        // Labels are cast to float before training the regressor.
        $trainingSet = Labeled::fromIterator(new CSV($trainingSetPath, header: false))
            ->transformLabels(static fn ($label): float => (float) $label);

        $learner = LearnerFactory::createLearner();
        $modelPath = $this->appPathResolver->getResourcesPath(FileNames::MODEL_FILENAME);
        $model = new PersistentModel($learner, new Filesystem($modelPath, $history));

        $model->train($trainingSet);
        $model->save();
    }
}

src/Domain/FileNames.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ final class FileNames
88
public const string VALID_DATA = 'valid.csv';
99
public const string TRAINING_SET = 'training.csv';
1010
public const string TEST_SET = 'test.csv';
11+
public const string MODEL_FILENAME = 'model.rbx';
1112
}

src/UI/CLI/SolveTaskCommand.php

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
use App\Application\Analytics\PowerCorrelationAnalyzer;
66
use App\Application\ML\PowerStationDatasetSplitter;
7+
use App\Application\ML\PowerStationReport;
8+
use App\Application\ML\PowerStationTester;
9+
use App\Application\ML\PowerStationTrainer;
710
use App\Domain\PowerStation;
811
use App\Infrastructure\Reader\PowerStationDatasetReader;
912
use App\Infrastructure\Writer\PowerStationValidDataWriter;
@@ -26,6 +29,9 @@ public function __construct(
2629
private readonly ValidatorInterface $validator,
2730
private readonly PowerStationValidDataWriter $powerStationValidDataWriter,
2831
private readonly PowerStationDatasetSplitter $powerStationDatasetSplitter,
32+
private readonly PowerStationTrainer $powerStationTrainer,
33+
private readonly PowerStationTester $powerStationTester,
34+
private readonly PowerStationReport $powerStationReport,
2935
) {
3036
parent::__construct();
3137
}
@@ -85,6 +91,16 @@ protected function execute(InputInterface $input, OutputInterface $output): int
8591
$io->info('Splitting the data into training and test sets...');
8692
$this->powerStationDatasetSplitter->split();
8793

94+
$io->info('Training the model...');
95+
$this->powerStationTrainer->train();
96+
97+
$io->info('Testing the model...');
98+
$score = $this->powerStationTester->test();
99+
$io->success(sprintf('MeanSquaredError: %f', abs($score)));
100+
101+
$io->info('Generating model report...');
102+
dump($this->powerStationReport->generateReport());
103+
88104
$io->success('OK');
89105

90106
return Command::SUCCESS;

0 commit comments

Comments
 (0)