I don't know how to plot the cost function as a graph.

병현 문 on 25 May 2022
Answered: Pavan Sahith on 17 Oct 2023
The problem is to estimate the parameters of the line y = ax + b from three data points: (1,1), (2,3), and (2,2).
I need to plot its cost function in 3D.
The ranges for a and b are both [0, 5].
I estimated the line as y = 1.5x - 0.5,
but I don't understand what it means to find the cost function here.

Answers (1)

Pavan Sahith on 17 Oct 2023
Hello,
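I understand you want to estimate the parameters of the linear equation y = ax + b using the three data points (1,1), (2,3), and (2,2), and to plot the cost function in 3D.
The cost function measures how badly a given pair (a, b) fits the data. The most commonly used one for linear regression is the mean squared error (MSE): the average of the squared vertical distances between each data point and the line y = ax + b. Evaluating it over a grid of (a, b) values gives a 3D surface, and the estimated parameters are the point where that surface is lowest.
Please refer to this example code, which evaluates the MSE for the given data over a, b in [0, 5] and plots it as a surface: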
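For the three points in the question, the MSE can be written out explicitly as a function of the two parameters. The derivation below is worked by hand as an illustration (it is not part of the original answer); it shows the quadratic surface being plotted and checks that its minimizer is the line the asker found:

```latex
\begin{aligned}
J(a,b) &= \frac{1}{3}\sum_{i=1}^{3}\bigl(y_i - (a x_i + b)\bigr)^2 \\
       &= \frac{1}{3}\Bigl[(1 - a - b)^2 + (3 - 2a - b)^2 + (2 - 2a - b)^2\Bigr] \\
       &= \frac{1}{3}\bigl(9a^2 + 10ab + 3b^2 - 22a - 12b + 14\bigr), \\
\frac{\partial J}{\partial a} = 0,\ \frac{\partial J}{\partial b} = 0
  &\;\Longrightarrow\; 9a + 5b = 11,\quad 5a + 3b = 6 \\
  &\;\Longrightarrow\; a = 1.5,\quad b = -0.5,\quad J(1.5,\,-0.5) = \tfrac{1}{6}.
\end{aligned}
```

So the cost function is a convex, bowl-shaped surface over the (a, b) plane, and y = 1.5x - 0.5 is its lowest point. Note that b = -0.5 lies just outside the requested range [0, 5], so the true minimum is not visible on a plot restricted to that square.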
% Data points
data = [1, 1; 2, 3; 2, 2];
% Function to compute the cost (MSE)
cost_function = @(a, b) mean((data(:, 2) - (a * data(:, 1) + b)).^2);
% Grid of 'a' and 'b' values over the required range [0, 5]
a_values = linspace(0, 5, 100);
b_values = linspace(0, 5, 100);
% Create a meshgrid for 3D plotting
[a_grid, b_grid] = meshgrid(a_values, b_values);
% Evaluate the cost at every (a, b) grid point
cost_grid = zeros(size(a_grid));
for i = 1:numel(a_grid)
    cost_grid(i) = cost_function(a_grid(i), b_grid(i));
end
% Plot the cost surface
figure;
surf(a_grid, b_grid, cost_grid);
xlabel('a');
ylabel('b');
zlabel('Cost');
title('Cost Function Plot');
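% You can also find the best a and b on the grid.
% Note: the least-squares line is a = 1.5, b = -0.5, and b = -0.5 is outside
% [0, 5], so this returns the lowest point inside the plotted range (on the
% edge b = 0), not the unconstrained minimum.
[min_cost, min_index] = min(cost_grid(:));
best_a = a_grid(min_index);
best_b = b_grid(min_index);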
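As a cross-check of the asker's estimate, here is a minimal sketch (not part of the original answer) that computes the least-squares line directly with MATLAB's backslash operator and evaluates the MSE there. It also shows a variant of the grid computation that loops over the three data points instead of over all 10,000 grid points; the variable names `x`, `y`, `X`, `p`, `a_ls`, `b_ls`, and `J_min` are introduced here for illustration.

```matlab
% Data points (x, y) from the question
data = [1, 1; 2, 3; 2, 2];
x = data(:, 1);
y = data(:, 2);
n = numel(x);

% Closed-form least-squares fit: solve [x 1] * [a; b] ~ y.
% For this overdetermined 3x2 system, backslash returns the
% least-squares solution.
X = [x, ones(n, 1)];
p = X \ y;          % p(1) = a, p(2) = b
a_ls = p(1);
b_ls = p(2);

% MSE cost at the least-squares parameters
J_min = mean((y - (a_ls * x + b_ls)).^2);
fprintf('a = %.4f, b = %.4f, cost = %.4f\n', a_ls, b_ls, J_min);
% prints: a = 1.5000, b = -0.5000, cost = 0.1667

% Same cost surface over [0, 5] x [0, 5] without an element-wise loop:
% accumulate each data point's squared residual over the whole grid.
[a_grid, b_grid] = meshgrid(linspace(0, 5, 100), linspace(0, 5, 100));
cost_grid = zeros(size(a_grid));
for k = 1:n
    cost_grid = cost_grid + (y(k) - (a_grid * x(k) + b_grid)).^2;
end
cost_grid = cost_grid / n;   % ready to pass to surf(a_grid, b_grid, cost_grid)
```

Looping over the data points keeps each step a whole-array operation, which is usually faster than calling an anonymous function once per grid point.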
Please refer to the MathWorks documentation links to know more about
