Make all threads are joined

in join_threads, nb_thread is the id of the last thread, not the number
of threads to join. Hence the for loop must include this id.
This commit is contained in:
Thomas Preud'homme 2011-06-01 15:16:49 +02:00
parent f0c75c7570
commit a30a5bfe06
1 changed files with 15 additions and 17 deletions

View File

@ -625,12 +625,11 @@ static int set_node_params(node_param_t *node_param,
return 0;
}
static int create_threads(int *p_nb_nodes, pthread_t *node_tids,
static int create_threads(int nb_nodes, int *last_node, pthread_t *node_tids,
node_param_t *node_params)
{
int i, nb_nodes;
int i;
nb_nodes = *p_nb_nodes;
for (i = 0; i < nb_nodes; i++)
{
node_param_t *prev_node_param;
@ -638,7 +637,7 @@ static int create_threads(int *p_nb_nodes, pthread_t *node_tids,
prev_node_param = (i) ? &node_params[i - 1] : NULL;
if (set_node_params(&node_params[i], prev_node_param, i))
{
*p_nb_nodes = i - 1;
*last_node = i - 1;
return -1;
}
if (pthread_create(&node_tids[i], NULL,
@ -646,20 +645,20 @@ static int create_threads(int *p_nb_nodes, pthread_t *node_tids,
{
perror("pthread_create node");
destroy_comm_channel(node_params[i].next_comm_channel);
*p_nb_nodes = i - 1;
*last_node = i - 1;
return -1;
}
}
*p_nb_nodes = i - 1;
*last_node = i - 1;
return 0;
}
static int join_threads(int nb_threads, pthread_t *tids)
static int join_threads(int last_node, pthread_t *tids)
{
int i, return_value;
void *pthread_return_value;
for (i = 0, return_value = 0; i < nb_threads; i++)
for (i = last_node, return_value = 0; i >= 0; i--)
{
pthread_join(tids[i], &pthread_return_value);
if (pthread_return_value != NULL)
@ -668,11 +667,11 @@ static int join_threads(int nb_threads, pthread_t *tids)
return return_value;
}
static int destroy_threads(int last_allocated, node_param_t *node_params)
static int destroy_threads(int last_node, node_param_t *node_params)
{
int i, return_value;
for (i = last_allocated, return_value = 0; i >= 0; i--)
for (i = last_node, return_value = 0; i >= 0; i--)
{
if (node_params[i].type != SINK)
{
@ -685,7 +684,7 @@ static int destroy_threads(int last_allocated, node_param_t *node_params)
int main(int argc, char *argv[])
{
int return_value, nb_threads;
int return_value, last_node;
pthread_t *tids;
node_param_t *node_params;
@ -695,22 +694,21 @@ int main(int argc, char *argv[])
page_size = sysconf(_SC_PAGE_SIZE);
if (page_size <= 0)
return EXIT_FAILURE;
nb_threads = nb_nodes;
node_params = malloc(nb_threads * sizeof(node_param_t));
node_params = malloc(nb_nodes * sizeof(node_param_t));
if (node_params == NULL)
return EXIT_FAILURE;
tids = malloc(nb_threads * sizeof(pthread_t));
tids = malloc(nb_nodes * sizeof(pthread_t));
if (tids == NULL)
{
return_value = EXIT_FAILURE;
goto error_alloc_tids;
}
if (create_threads(&nb_threads, tids, node_params))
if (create_threads(nb_nodes, &last_node, tids, node_params))
goto error_create_channels;
if (join_threads(nb_threads, tids))
if (join_threads(last_node, tids))
return_value = EXIT_FAILURE;
error_create_channels:
if (destroy_threads(nb_threads, node_params))
if (destroy_threads(last_node, node_params))
return_value = EXIT_FAILURE;
free(tids);
error_alloc_tids: